# https://github.com/matplotlib/matplotlib/issues/5836#issuecomment-179592427
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import matplotlib.pyplot as plt

タイタニック号 災害から学ぶ機械学習

データの取得

Kaggle API を使って Titanic コンペティションのデータを取得します。

# !kaggle competitions download -c titanic --quiet

ダウンロードする前に、Kaggle API の認証情報ファイル kaggle.json を ~/.kaggle/kaggle.json に配置しておく必要があります。ダウンロードが完了すると titanic.zip が作成されるので、カレントディレクトリに展開します。

# !unzip titanic.zip

train.csv ファイルに訓練用のデータが保存されていますので pandas.read_csv で読み込みます。

# import pandas as pd

# train = pd.read_csv("train.csv")
# train.head(3)

test.csv ファイルにテスト用のデータが保存されていますので同様に pandas.read_csv で読み込みます。

# import pandas as pd

# test = pd.read_csv("test.csv")
# test.head(3)

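また、seaborn というライブラリを使用してもデータを取得することができます。ただし、seaborn のデータセットの列名は小文字（age, sex, survived など）であり、Kaggle の train.csv の列名（Age, Sex, Survived など）とは大文字・小文字が異なる点に注意してください。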

# see https://seaborn.pydata.org/examples/logistic_regression.html
import seaborn as sns

# Load the example Titanic dataset bundled with seaborn.
# NOTE: seaborn's copy uses lower-case column names ("age", "sex",
# "survived"), while Kaggle's train.csv uses capitalized ones ("Age",
# "Sex", "Survived"). Requesting the capitalized names against the
# seaborn data raised a KeyError.
df = sns.load_dataset("titanic")
# To use Kaggle's train.csv instead (see the cells above): df = train

# Normalize column names to lower case so the plot below works with
# either data source.
df = df.rename(columns=str.lower)

# Make a custom palette with gendered colors
# (both sources encode sex as "male" / "female").
pal = dict(male="#6495ED", female="#F08080")

# Show the survival probability as a function of age and sex:
# one facet per sex, with a logistic-regression fit of survived ~ age.
# y_jitter only spreads the plotted 0/1 points vertically for readability.
g = sns.lmplot(
    x="age",
    y="survived",
    col="sex",
    hue="sex",
    data=df,
    palette=pal,
    y_jitter=0.02,
    logistic=True,
    truncate=False,
)
# Fix the axes so both facets share the same age range and a 0..1
# probability scale (with a small margin for the jittered points).
g.set(xlim=(0, 80), ylim=(-0.05, 1.05))
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[6], line 23
      9 pal = dict(male="#6495ED", female="#F08080")
     11 # Show the survival probability as a function of age and sex
     12 # g = sns.lmplot(
     13 #     x="age",
   (...)
     21 #     truncate=False,
     22 # )
---> 23 g = sns.lmplot(
     24     x="Age",
     25     y="Survived",
     26     col="Sex",
     27     hue="Sex",
     28     data=df,
     29     palette=pal,
     30     y_jitter=0.02,
     31     logistic=True,
     32     truncate=False,
     33 )
     34 g.set(xlim=(0, 80), ylim=(-0.05, 1.05))

File ~/checkouts/readthedocs.org/user_builds/tkoyama010-notebooks/conda/latest/lib/python3.11/site-packages/seaborn/regression.py:595, in lmplot(data, x, y, hue, col, row, palette, col_wrap, height, aspect, markers, sharex, sharey, hue_order, col_order, row_order, legend, legend_out, x_estimator, x_bins, x_ci, scatter, fit_reg, ci, n_boot, units, seed, order, logistic, lowess, robust, logx, x_partial, y_partial, truncate, x_jitter, y_jitter, scatter_kws, line_kws, facet_kws)
    593 need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    594 cols = np.unique([a for a in need_cols if a is not None]).tolist()
--> 595 data = data[cols]
    597 # Initialize the grid
    598 facets = FacetGrid(
    599     data, row=row, col=col, hue=hue,
    600     palette=palette,
   (...)
    603     **facet_kws,
    604 )

File ~/checkouts/readthedocs.org/user_builds/tkoyama010-notebooks/conda/latest/lib/python3.11/site-packages/pandas/core/frame.py:3767, in DataFrame.__getitem__(self, key)
   3765     if is_iterator(key):
   3766         key = list(key)
-> 3767     indexer = self.columns._get_indexer_strict(key, "columns")[1]
   3769 # take() does not accept boolean indexers
   3770 if getattr(indexer, "dtype", None) == bool:

File ~/checkouts/readthedocs.org/user_builds/tkoyama010-notebooks/conda/latest/lib/python3.11/site-packages/pandas/core/indexes/base.py:5877, in Index._get_indexer_strict(self, key, axis_name)
   5874 else:
   5875     keyarr, indexer, new_indexer = self._reindex_non_unique(keyarr)
-> 5877 self._raise_if_missing(keyarr, indexer, axis_name)
   5879 keyarr = self.take(indexer)
   5880 if isinstance(key, Index):
   5881     # GH 42790 - Preserve name from an Index

File ~/checkouts/readthedocs.org/user_builds/tkoyama010-notebooks/conda/latest/lib/python3.11/site-packages/pandas/core/indexes/base.py:5938, in Index._raise_if_missing(self, key, indexer, axis_name)
   5936     if use_interval_msg:
   5937         key = list(key)
-> 5938     raise KeyError(f"None of [{key}] are in the [{axis_name}]")
   5940 not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())
   5941 raise KeyError(f"{not_found} not in index")

KeyError: "None of [Index(['Age', 'Sex', 'Survived'], dtype='object')] are in the [columns]"