Java Deep Learning Projects

Problem description

Before diving into the coding, let's look at a short description of the problem. The following paragraph is quoted directly from the Kaggle Titanic survival prediction page:

"The sinking of the RMS Titanic is one of the most infamous shipwrecks in history. On April 15, 1912, during her maiden voyage, the Titanic sank after colliding with an iceberg, killing 1502 out of 2224 passengers and crew. This sensational tragedy shocked the international community and led to better safety regulations for ships. One of the reasons that the shipwreck led to such loss of life was that there were not enough lifeboats for the passengers and crew. Although there was some element of luck involved in surviving the sinking, some groups of people were more likely to survive than others, such as women, children, and the upper class. In this challenge, we ask you to complete the analysis of what sorts of people were likely to survive. In particular, we ask you to apply the tools of machine learning to predict which passengers survived the tragedy."

Before going deeper, we need to understand the data on the passengers who were aboard the Titanic during the disaster, so that we can develop a predictive model for survival analysis. The dataset can be downloaded from https://github.com/rezacsedu/TitanicSurvivalPredictionDataset. It contains two .csv files:

  • The training set (train.csv): Used to build your ML models. It also includes the ground-truth label for each passenger, that is, whether that passenger survived.
  • The test set (test.csv): Used to see how well your model performs on unseen data. Unlike the training set, it does not include the ground truth for each passenger.
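Before any modeling, the two files have to be read. If the files follow the Kaggle format, the Name column is quoted and itself contains commas (for example, "Braund, Mr. Owen Harris"), so a naive `line.split(",")` would shift every column after it. The following is a minimal plain-Java sketch of a quote-aware line splitter, assuming the standard Kaggle header (`PassengerId,Survived,Pclass,Name,...`); the class name `TitanicCsv` and the default path `data/train.csv` are hypothetical. Spark's own CSV reader handles quoted fields for you, so this only illustrates what the raw files look like.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical helper: a minimal quote-aware CSV line splitter.
 * A sketch only, not a full RFC 4180 parser (no line breaks inside quoted fields).
 */
final class TitanicCsv {

    /** Splits one CSV line, honoring double-quoted fields and "" escapes. */
    static List<String> splitLine(String line) {
        List<String> fields = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        boolean inQuotes = false;
        for (int i = 0; i < line.length(); i++) {
            char c = line.charAt(i);
            if (inQuotes) {
                if (c == '"') {
                    // A doubled quote inside a quoted field is a literal quote character.
                    if (i + 1 < line.length() && line.charAt(i + 1) == '"') {
                        current.append('"');
                        i++;
                    } else {
                        inQuotes = false; // closing quote
                    }
                } else {
                    current.append(c); // commas inside quotes are kept as data
                }
            } else if (c == '"') {
                inQuotes = true; // opening quote
            } else if (c == ',') {
                fields.add(current.toString()); // field boundary
                current.setLength(0);
            } else {
                current.append(c);
            }
        }
        fields.add(current.toString()); // last field, possibly empty
        return fields;
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical local path; pass the real location of train.csv or test.csv as the first argument.
        Path path = Path.of(args.length > 0 ? args[0] : "data/train.csv");
        List<String> lines = Files.readAllLines(path);
        List<String> header = splitLine(lines.get(0));
        System.out.println("Columns: " + header);
        // Only the training file is expected to carry the Survived label.
        System.out.println("Has labels: " + header.contains("Survived"));
        System.out.println("Rows: " + (lines.size() - 1));
    }
}
```

Running it once on train.csv and once on test.csv is a quick way to confirm that only the training file contains the label column.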

In short, for each passenger in the test set, we have to use the trained model to predict whether they survived the sinking of the Titanic. Table 1 shows the metadata of the training set:

Now the question is: using this labeled data, can we draw some straightforward conclusions? For example, being a woman, being in first class, or being a child might each have boosted a passenger's chances of surviving the disaster.
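Checking hypotheses like these boils down to a group-by: the fraction of survivors within each value of a column such as Sex or Pclass. Here is a minimal plain-Java sketch, assuming the labels have already been parsed into 0/1 integers (as in the Survived column). The class name `SurvivalRates` is hypothetical, and the values in `main` are toy data for illustration only, not figures from the dataset.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical helper: survival rate per value of one categorical column. */
final class SurvivalRates {

    /**
     * Returns survivors / total for each distinct group value, in sorted key order.
     * Assumes every label is 0 (died) or 1 (survived).
     */
    static Map<String, Double> byGroup(List<String> groups, List<Integer> survived) {
        if (groups.size() != survived.size()) {
            throw new IllegalArgumentException("groups and labels must have equal length");
        }
        Map<String, int[]> counts = new TreeMap<>(); // group -> {survivors, total}
        for (int i = 0; i < groups.size(); i++) {
            int[] c = counts.computeIfAbsent(groups.get(i), k -> new int[2]);
            c[0] += survived.get(i); // adds 1 only for survivors
            c[1] += 1;
        }
        Map<String, Double> rates = new TreeMap<>();
        counts.forEach((g, c) -> rates.put(g, (double) c[0] / c[1]));
        return rates;
    }

    public static void main(String[] args) {
        // Hypothetical toy values, NOT the real Titanic data.
        List<String> sex = List.of("female", "male", "female", "male", "male", "female");
        List<Integer> survived = List.of(1, 0, 1, 1, 0, 0);
        System.out.println(byGroup(sex, survived));
    }
}
```

Feeding it the Sex or Pclass column of train.csv, alongside the Survived column, gives a first sanity check of the intuitions above before any model is trained.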

To solve this problem, we can start with a basic multilayer perceptron (MLP), one of the oldest deep learning algorithms. For this, we use the MultilayerPerceptronClassifier from Spark ML. At this point, you might be wondering why I am talking about Spark, since it is not a DL library. However, Spark provides an MLP implementation, which is sufficient for our purpose.
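To make the MLP less abstract, here is a plain-Java sketch of what such a network computes at prediction time. Following Spark's documentation for MultilayerPerceptronClassifier, hidden layers use the sigmoid function and the output layer uses softmax, with one output node per class. In Spark you only declare the layer sizes (via `setLayers(int[])`, where the first entry is the number of features and the last is the number of classes), and the weights are learned during `fit()`. The class name `MlpForward`, the weights, and the feature values below are hypothetical and chosen only for illustration; this is not Spark's internal code.

```java
import java.util.Arrays;

/** Hypothetical sketch of an MLP forward pass: sigmoid hidden layers, softmax output. */
final class MlpForward {

    /** weights[l][j][i] connects unit i of layer l to unit j of layer l + 1. */
    static double[] forward(double[] x, double[][][] weights, double[][] biases) {
        double[] a = x;
        int last = weights.length - 1;
        for (int l = 0; l <= last; l++) {
            double[] z = affine(a, weights[l], biases[l]);
            a = (l < last) ? sigmoid(z) : softmax(z); // softmax only on the output layer
        }
        return a;
    }

    /** Computes z = W x + b for one layer. */
    static double[] affine(double[] x, double[][] w, double[] b) {
        double[] z = new double[w.length];
        for (int j = 0; j < w.length; j++) {
            double s = b[j];
            for (int i = 0; i < x.length; i++) {
                s += w[j][i] * x[i];
            }
            z[j] = s;
        }
        return z;
    }

    static double[] sigmoid(double[] z) {
        double[] out = new double[z.length];
        for (int j = 0; j < z.length; j++) {
            out[j] = 1.0 / (1.0 + Math.exp(-z[j]));
        }
        return out;
    }

    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().orElse(0.0); // subtract the max for numerical stability
        double[] out = new double[z.length];
        double sum = 0.0;
        for (int j = 0; j < z.length; j++) {
            out[j] = Math.exp(z[j] - max);
            sum += out[j];
        }
        for (int j = 0; j < z.length; j++) {
            out[j] /= sum; // probabilities now sum to 1
        }
        return out;
    }

    public static void main(String[] args) {
        // Layer sizes {3, 2, 2}: 3 hypothetical features, 2 hidden units, 2 classes (0 = died, 1 = survived).
        double[][][] w = {
            {{0.5, -0.2, 0.1}, {-0.3, 0.8, 0.4}}, // input -> hidden (2 x 3)
            {{1.0, -1.0}, {-1.0, 1.0}}            // hidden -> output (2 x 2)
        };
        double[][] b = {{0.0, 0.1}, {0.0, 0.0}};
        double[] probs = forward(new double[]{1.0, 0.0, 0.5}, w, b);
        int predicted = probs[1] > probs[0] ? 1 : 0; // pick the class with the highest probability
        System.out.println(Arrays.toString(probs) + " -> class " + predicted);
    }
}
```

The predicted class is simply the output with the highest probability; for a binary problem like this one, that means comparing the two softmax outputs.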

Then, from the next chapter onward, we'll gradually move to more robust DNNs built with DeepLearning4J, a JVM-based framework for developing deep learning applications. But first, let's see how to configure our Spark environment.