This repository was archived by the owner on Aug 28, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathtutorial.py
More file actions
110 lines (89 loc) · 2.55 KB
/
tutorial.py
File metadata and controls
110 lines (89 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from flash import Trainer
from flash.tabular import TabularClassificationData, TabularClassifier
# %% [markdown]
# ## 1. Create the DataModule
#
# ### Variable & Definition
#
# - survival: Survival (0 = No, 1 = Yes)
# - pclass: Ticket class (1 = 1st, 2 = 2nd, 3 = 3rd)
# - sex: Sex
# - Age: Age in years
# - sibsp: number of siblings / spouses aboard the Titanic
# - parch: number of parents / children aboard the Titanic
# - ticket: Ticket number
# - fare: Passenger fare
# - cabin: Cabin number
# - embarked: Port of Embarkation
# %%
# Dataset root: read from the PATH_DATASETS environment variable, falling back to "_datasets".
data_path = os.environ.get("PATH_DATASETS", "_datasets")
path_titanic = os.path.join(data_path, "titanic")
# Both CSV splits are located directly inside the "titanic" folder.
csv_train, csv_test = (os.path.join(path_titanic, split_file) for split_file in ("train.csv", "test.csv"))

# Load the training split and plot a two-bin histogram of the target column.
df_train = pd.read_csv(csv_train)
df_train["Survived"].hist(bins=2)
# %%
# Build the datamodule straight from the training CSV: "Survived" is the target,
# the remaining listed columns are split into categorical and numerical inputs.
datamodule = TabularClassificationData.from_csv(
    train_file=csv_train,
    target_fields="Survived",
    categorical_fields=["Sex", "Embarked", "Cabin"],
    numerical_fields=["Fare", "Age", "Pclass", "SibSp", "Parch"],
    val_split=0.1,
    batch_size=32,
)
# %% [markdown]
# ## 2. Build the task
# %%
# Create the classification task from the datamodule, together with the
# optimizer settings and the extra `n_a` / `gamma` hyper-parameters.
model = TabularClassifier.from_data(
    datamodule,
    optimizer="AdamW",
    learning_rate=0.1,
    n_a=8,
    gamma=0.3,
)
# %% [markdown]
# ## 3. Create the trainer and train the model
# %%
from pytorch_lightning.loggers import CSVLogger  # noqa: E402

# Log to CSV files under "logs/"; the metrics file is read back after training.
logger = CSVLogger(save_dir="logs/")
trainer = Trainer(
    logger=logger,
    max_epochs=10,
    # One entry per CUDA device reported by torch (0 when none is available).
    gpus=torch.cuda.device_count(),
    gradient_clip_val=0.1,
    accumulate_grad_batches=12,
)
# %%
# Train the task on the Titanic datamodule.
trainer.fit(model, datamodule=datamodule)
# %%
# Read the logged metrics back, index them by training step and drop the
# "epoch" column so that only metric curves are plotted.
metrics = (
    pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
    .set_index("step")
    .drop(columns="epoch")
)
sns.relplot(data=metrics, kind="line")
# Fix the y-range and figure size of the current plot, then overlay a grid.
plt.gca().set_ylim(0, 1.25)
plt.gcf().set_size_inches(10, 5)
plt.grid()
# %% [markdown]
# ## 4. Generate predictions from a CSV
# %%
# Load the test split; it is only used for prediction.
df_test = pd.read_csv(csv_test)

# Build a predict-only datamodule that reuses the `parameters` and batch size
# of the training datamodule.
dm = TabularClassificationData.from_data_frame(
    parameters=datamodule.parameters,
    predict_data_frame=df_test,
    batch_size=datamodule.batch_size,
)

preds = trainer.predict(model, datamodule=dm, output="classes")
# Peek at the first ten entries of the first element of `preds`
# (presumably the first batch).
print(preds[0][:10])
# %%
import itertools  # noqa: E402

import numpy as np  # noqa: E402

# `preds` holds one sequence per element (presumably one per batch); flatten it
# so there is a single prediction entry per test row.
predictions = list(itertools.chain.from_iterable(preds))
if len(predictions) != len(df_test):
    raise ValueError(f"Expected {len(df_test)} predictions (one per test row), got {len(predictions)}.")

predicted_classes = np.asarray(predictions)
# NOTE(review): with `output="classes"` each entry is expected to already be a
# class index (Flash `ClassesOutput`) - confirm against the installed Flash version.
# Taking argmax over such a flat sequence would collapse it to one scalar that
# pandas then broadcasts to every row, so only reduce when a trailing per-class
# axis (scores / probabilities) is actually present.
if predicted_classes.ndim > 1:
    predicted_classes = predicted_classes.argmax(axis=-1)

df_test["Survived"] = predicted_classes
df_test.set_index("PassengerId", inplace=True)
df_test["Survived"].hist(bins=5)