diff --git a/docs/source/qml_examples/examples.ipynb b/docs/source/qml_examples/examples.ipynb index 48181bb14..1f1ed88e8 100644 --- a/docs/source/qml_examples/examples.ipynb +++ b/docs/source/qml_examples/examples.ipynb @@ -179,6 +179,63 @@ "Note that ``mol.representation`` is just a 1D numpy array." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generating BoB representation and kernel with per-group kernel widths\n", + "\n", + "Usually, there is only a single hyperparamer for in kernel regression, the kernel width.\n", + "However, it may beneficial for the model to use different kernel widths for different parts of the representation, e.g. in BoB to have different sigma for CC bond than for HH bonds. Below is an example how to achieve this with QMLcode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from qml.data import Compound\n", + "from qml.helpers import get_BoB_groups, compose_BoB_sigma_vector\n", + "from qml.kernels import gaussian_sigmas_kernel, laplacian_sigmas_kernel\n", + "\n", + "# Specify maximal number of atoms (per atomtype) in molecules in dataset,\n", + "# e.g. if there are two molecule CO2 and H2O, then asize would be\n", + "# {\"C\":1, \"H\":2, \"O\":2} as maximum number of carbons is 1 (in CO2),\n", + "# maximum number of hydrogens is 2 (in H20) and maximum number of oxygens\n", + "# is 2 (in CO2).\n", + "asize = {\"O\":5, \"F\":6, \"N\":7, \"C\":9, \"H\":20} # asize for QM9\n", + "asize = {k: asize[k] for k in sorted(asize, key=asize.get)} # Sorting\n", + "\n", + "# Assume the QM9 dataset is a list of Compound objects\n", + "for compound in qm9:\n", + " # Generate the BoB representation for each compound\n", + " compound.generate_bob(size=29, asize=asize)\n", + "\n", + "# Generate a vector of representations (feature vectors) of compounds\n", + "X = np.array([c.representation for c in compounds])\n", + "\n", + "# Bags of bonds are sorted in every feature vector,\n", + "# get initial and ending indices for every group\n", + "low_indices, high_indices = get_BoB_groups(asize)\n", + "\n", + "# Specify kernel widths for each group\n", + "sigmas_for_bags = {\"OO\":3797., \"OF\":7.69e5, \"ON\":117., \"OC\":337., \"OH\":26.,\n", + " \"FF\":1.88e7, \"FN\":5.78e6, \"FC\":291., \"FH\":2.83e5,\n", + " \"NN\":5.82e6, \"NC\":180., \"NH\":26.6,\n", + " \"CC\":69., \"CH\":8.92e6,\n", + " \"HH\":5.48}\n", + "\n", + "# Get per-feature vector of kernel widths\n", + "# Per-group would be enough, but per-feature vector is the same size\n", + "# as representation vector and handling it with Fortran is easier\n", + "sigmas_vector = compose_BoB_sigma_vector(sigmas_for_bags, low_indices, high_indices)\n", + "\n", + "# Generate a kernel with per-group specific kernel widths\n", + "K = gaussian_sigmas_kernel(X, X, sigmas_vector)\n", + "# K = laplacian_sigmas_kernel(X, X, sigmas_vector)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -458,17 +515,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "100 files were loaded.\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from qml.aglaia.aglaia import ARMP\n", "import matplotlib.pyplot as plt\n", @@ -494,7 +543,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -511,17 +560,9 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(100, 19, 165)\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "estimator.generate_compounds(filenames)\n", "estimator.generate_representation(method=\"fortran\")\n", @@ -539,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -555,7 +596,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -573,27 +614,9 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The RMSE is 0.05667379826437678 kJ/mol\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAAEKCAYAAAA8QgPpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XucXVV99/HPl0mA4aIDEmiuTcA0NohcnEIQ64OiJKFIoi9pQ/UxYmoeKxSRNkAKT/GCF4wVpLXaKGhURCjGECgSI0JpbROZkEC4pYSAMAmSKCQoTCGXX//Y6yRnJmfOnDlzrjPf9+t1Xmfvtdc565cNM7/Za+29liICMzOz/tqn3gGYmVlzcgIxM7OyOIGYmVlZnEDMzKwsTiBmZlYWJxAzMyuLE4iZmZXFCcTMzMriBGJmZmUZVu8Aqumwww6L8ePH1zsMM7OmsmrVql9HxIi+6g3qBDJ+/Hg6OjrqHYaZWVOR9MtS6rkLy8zMyuIEYmZmZXECMTOzsjiBmJlZWZxAzMysLIP6Liwzs6FmyeqNLFi2jk1buxjV1sq8qZOYefzoqrTlBGJmNkgsWb2R+YvX0rV9JwAbt3Yxf/FagKokkap3YUm6XtJmSQ8VOPY3kkLSYWlfkq6VtF7Sg5JOyKs7W9Lj6TW72nGbmTWbBcvW7U4eOV3bd7Jg2bqqtFeLMZBvA9N6FkoaC7wLeDqveDowMb3mAl9LdQ8FrgBOAk4ErpB0SFWjNjNrMpu2dvWrfKCqnkAi4l7g+QKHrgYuBiKvbAbwncisANokjQSmAssj4vmIeAFYToGkZGY2lI1qa+1X+UDV5S4sSWcBGyPigR6HRgPP5O13prLeys3MLJk3dRKtw1u6lbUOb2He1ElVaa/mg+iSDgAuA04vdLhAWRQpL/T9c8m6vxg3blyZUZqZNZ/cQPlgvgvrKGAC8IAkgDHA/ZJOJLuyGJtXdwywKZWf2qP8nkJfHhELgYUA7e3tBZOMmdlgNfP40VVLGD3VvAsrItZGxOERMT4ixpMlhxMi4lfAUuCD6W6sKcC2iHgWWAacLumQNHh+eiozM7M6qcVtvDcC/wVMktQpaU6R6ncAG4D1wDeAjwFExPPAZ4D70uvTqczMzOpEEYO3l6e9vT28HoiZWf9IWhUR7X3V81xYZmZWFicQMzMrixOImZmVxQnEzMzK4gRiZmZlcQIxM7OyOIGYmVlZnEDMzKwsTiBmZlYWJxAzMyuLE4iZmZXFCcTMzMriBGJmZmVxAjEzs7I4gZiZWVmcQMzMrCy1WJHwekmbJT2UV7ZA0mOSHpT0I0ltecfmS1ovaZ2kqXnl01LZekmXVjtuMzMrrhZXIN8GpvUoWw68MSLeBPw3MB9A0mRgFnB0+sw/SWqR1AJ8FZgOTAbOSXXNzKxOqp5AIuJe4PkeZT+JiB1pdwUwJm3PAH4QEa9ExJNka6OfmF7rI2JDRLwK/CDVNTOzOmmEMZAPAz9O26OBZ/KOdaay3srNzKxO6ppAJF0G7ABuyBUVqBZFygt951xJHZI6tmzZUplAzcxsL3VLIJJmA2cC74+IXDLoBMbmVRsDbCpSvpeIWBgR7RHRPmLEiMoHbmZmQJ0SiKRpwCXAWRHxct6hpcAsSftJmgBMBH4B3AdMlDRB0r5kA+1Lax23mZntMazaDUi6ETgVOExSJ3AF2V1X+wHLJQGsiIiPRsTDkm4GHiHr2jovInam7zkfWAa0ANdHxMPVjt3MzHqnPb1Hg097e3t0dHTUOwwzs6YiaVVEtPdVr88rEEntwB8Do4Au4CHgpxHxfNEPmpnZoNbrGIikD0m6n6y7qRVYB2wG3krW9bRI0rjahGlmZo2m2BXIgcApEdFV6KCk48gGuZ+uRmBmZtbYek0gEfHVYh+MiDWVD8fMzJpFrwlE0rXFPhgRF1Q+HDMzaxbFurBW1SwKMzNrOsW6sBbl70s6OCuO31U9KjMza3h9Poku6Y2SVpPdvvuIpFWSjq5+aGZm1shKmcpkIXBRRPx+RIwD/hr4RnXDMjOzRldKAjkwIu7O7UTEPWS3+JqZ2RBWylxYGyT9f+C7af8DwJPVC8nMzJpBKVcgHwZGAIuBH6Xtc6sZlJmZNb4+r0Ai4gXAz3yYmVk3pU6m+LfA+Pz6EfGm6oVlZmaNrpQxkBuAecBaYFd1wzEzs2ZRSgLZEhFe/c/MzLopZRD9CknflHSOpPfmXqU2IOl6SZslPZRXdqik5ZIeT++HpHJJulbSekkPSjoh7zOzU/3H03rqZmZWR6UkkHOB44BpwLvT68x+tPHt9Nl8lwJ3RcRE4K60DzCdbIr4icBc4GuQJRyypXBPAk4kS2qH9CMGMzOrsFK6sI6NiGPKbSAi7pU0vkfxDLJ10gEWAfcAl6Ty70S2zu4KSW2SRqa6y3OrIEpaTpaUbiw3LjMzG5hSrkBWSJpc4XaPiIhnAdL74al8NPBMXr3OVNZbuZmZ1UkpVyBvBWZLehJ4BRDZrLzVuI1XBcqiSPneXyDNJev+Ytw4r7hrZlYtpSSQnuMXlfCcpJER8WzqotqcyjuBsXn1xgCbUvmpPcrvKfTFEbGQbAJI2tvbCyYZMzMbuF67sCR1SPoK8IfAcxHxy/zXANtdCuTupJoN3JpX/sF0N9YUYFvq4loGnC7pkDR4fnoqMzOzOil2BTKFrPtqGvApSb8h+6X944j471IbkHQj2dXDYZI6ye6m+gJws6Q5wNPA2an6HcAZwHrgZdKcWxHxvKTPAPelep/ODaibmVl9KLvhqYSKWVfTdLKEMhH4r4j4WBVjG7D29vbo6OiodxhmZk1F0qqIaO+rXiljIMDuu6WuB66XtA9w8gDiMzOzJtdrApF0G73c6UR2N9YTkp6OiGd6qWNmZoNYsSuQL/XxuaOBm/GViJnZkNRrAomIfwOQ9OaIWJV/TNK7I+JaSZ7S3cxsiCrlSfRvSNo9lYmkc4DLASLiL6oVmJmZNbZSBtHfB9wi6f1kt/V+kOw5DDMzG8JKWdJ2g6RZwBKy+ahOj4iuqkdmZmYNrdhdWGvpfhfWoUALsFKSl7Q1Mxviil2B9GfNDzMzG2KKJZDfRMTvin1Y0kF91TEzs8Gp2F1Yt0r6e0lvk3RgrlDSkZLmSFpGdWbqNTOzJlDsOZDTJJ0B/D/glDQL7g5gHfCvwOyI+FVtwjQzs0ZT9C6siLiDbIZcMzOzbkp5kNDMzGwvTiBmZlYWJxAzMytLnwlE0lGS9kvbp0q6QFJbJRqX9AlJD0t6SNKNkvaXNEHSSkmPS7pJ0r6p7n5pf306Pr4SMZiZWXlKuQL5IbBT0uuB64AJwPcH2rCk0cAFQHtEvJHsKfdZwFXA1RExEXgBmJM+Mgd4ISJeD1yd6pmZWZ2UkkB2RcQO4D3ANRHxCWBkhdofBrRKGgYcADwLvAO4JR1fBMxM2zPSPun4aZJUoTjMzKyfSkkg29MU7rOB21PZ8IE2HBEbyRateposcWwDVgFbU8IC6ARGp+3RZJM5ko5vA1430DjMzKw8pSSQc8lWHfxsRDwpaQLwvYE2nB5MnEHWJTYKOBCYXqBqbkLHQlcbey25K2mupA5JHVu2bBlomGZm1os+E0hEPAJcAtyf9p+MiC9UoO13Ak9GxJaI2A4sBt4CtKUuLYAxwKa03QmMBUjHXws8XyDehRHRHhHtI0aMqECYZmZWSCl3Yb0bWAPcmfaPk7S0Am0/DUyRdEAayzgNeAS4m2wRK8i6zW5N20vTPun4zyJirysQMzOrjVK6sD4JnAhsBYiINWTdTgMSESvJBsPvB9amWBaSXe1cJGk92RjHdekj1wGvS+UXAZcONAYzMytfKUva7oiIbT1ueKrIX/4RcQVwRY/iDWQJq2fd/wHOrkS7ZmY2cKUkkIck/TnQImki2bMb/1ndsMzMrNGV0oX1V8DRwCvAjcCLwIXVDMrMzBpfn1cgEfEycFl6mZmZAUUSiKRrIuJCSbdRYMwjIs6qamRmVtCS1RtZsGwdm7Z2MaqtlXlTJzHz+NF9f9CswopdgXw3vX+pFoGYWd+WrN7I/MVr6dq+E4CNW7uYv3gtgJOI1VyxJW1Xpfd/q104ZlbMgmXrdiePnK7tO1mwbJ0TiNVcn2MgktaydxfWNqADuDIiflONwMxsb5u2dvWr3KyaSrmN98fATvZM4T6LbF6qbcC3gXdXJTIz28uotlY2FkgWo9pa6xCNDXWl3MZ7SkTMj4i16XUZ8H8i4ipgfHXDM7N886ZOonV4S7ey1uEtzJs6qU4R2VBWSgI5SNJJuR1JJwIHpd0dhT9iZtUw8/jRfP69xzC6rRUBo9ta+fx7j/H4h9VFKV1Yc4BvScoljd8CcyQdCHy+apGZWUEzjx/thGENoWgCkbQPcGREHCPptYAiYmtelZurGp2ZmTWsol1YEbELOD9tb+uRPMzMbAgrZQxkuaS/kTRW0qG5V9UjMzOzhlbKGMiH0/t5eWUBHFn5cMzMrFmUMpnigBePMjOzwaeUJW0PkHS5pIVpf6KkMyvRuKQ2SbdIekzSo5JOTl1kyyU9nt4PSXUl6VpJ6yU9KOmESsRgZmblKWUM5FvAq8Bb0n4ncGWF2v8KcGdEvAE4FniUbKnauyJiInAXe5aunQ5MTK+5wNcqFIOZmZWhlARyVER8EdgOEBFdZFOZDIik1wBvI615HhGvpru8ZgCLUrVFwMy0PQP4TmRWAG2SRg40DjMzK08pCeRVSa2kCRUlHUW2OuFAHQlsIXtIcbWkb6aHE4+IiGcB0vvhqf5o4Jm8z3emMjMzq4NSEsgVwJ3AWEk3kHUrXVyBtocBJwBfi4jjgZfY011VSKGrnr0WupI0V1KHpI4tW7ZUIEwzMyukzwQSEcuB9wIfIlsTvT0i7qlA251AZ0SsTPu3kCWU53JdU+l9c179sXmfHwNsKhDvwohoj4j2ESNGVCBMMzMrpJQrEID9gReAF4HJkt420IYj4lfAM5Jy04ieBjwCLAVmp7LZwK1peynwwXQ31hRgW66ry8zMaq+UBaWuAv4MeBjYlYoDuLcC7f8VcIOkfYENwLlkSe1mSXOAp4GzU907gDOA9cDLqa6ZmdVJKU+izwQmRUQlBs67iYg1QHuBQ6cVqBt0fxrezMzqqJQurA3A8GoHYmZmzaWUK5CXgTWS7iLv9t2IuKBqUZmZWcMrJYEsTS8zM7Pdek0gkl4TES9GxKICx8ZVNywzM2t0xcZA7sltpO6rfEuqEo2ZmTWNYgkk/8nvngtIDXguLDMza27FEkj0sl1o38zMhphig+iHS7qI7Gojt03a9xwhZmZDXLEE8g3g4ALbAN+sWkRmZtYUek0gEfGpWgZiZmbNpdTJFM3MzLpxAjEzs7I4gZiZWVmKPYl+UW/HACLiy5UPx8zMmkWxu7Byd11NAv6IPfNhvZvKrAViZmZNrM+7sCT9BDghIn6b9j8J/EtNojMzs4ZVyhjIOODVvP1XgfGVCkBSi6TVkm5P+xMkrZT0uKSb0mqFSNov7a9PxysWg5mZ9V8pCeS7wC8kfVLSFcBK4DsVjOHjwKN5+1cBV0fERLJ12Oek8jnACxHxeuDqVM/MzOqkzwQSEZ8lW3/8BWArcG5EfK4SjUsaA/wJ6cl2SQLeAdySqiwiW1IXYEbaJx0/LdU3M7M6KPU23gOAFyPiK0CnpAkVav8a4GJgV9p/HbA1Inak/U5gdNoeDTwDkI5vS/XNzKwO+kwgqdvqEmB+KhoOfG+gDUs6E9gcEavyiwtUjRKO5X/vXEkdkjq2bNky0DDNzKwXpVyBvAc4C3gJICI20X1ixXKdApwl6SngB2RdV9cAbZJyd4eNATal7U5gLEA6/lrg+Z5fGhELI6I9ItpHjPCkwWZm1VJKAnk1IoL0176kAyvRcETMj4gxETEemAX8LCLeD9wNvC9Vmw3cmraXpn3S8Z+luMzMrA5KSSA3S/pnsiuDjwA/pbrTuV8CXCRpPdkYx3Wp/Drgdan8IuDSKsZgZmZ9UCl/xEt6F3A62TjEsohYXu3AKqG9vT06OjrqHYaZWVORtCoi2vuqV2wqk9wXXRURlwDLC5SZmdkQ1WcCAd5F1q2Ub3qBMrMBuXzJWm5c+Qw7I2iROOeksVw585h6h2VmvSg2G+9fAh8DjpL0YN6hg4H/rHZgNrRcvmQt31vx9O79nRG7951EzBpTsUH075PNvHtres+93pzuljKrmBtXPtOvcjOrv2Kz8W4Dtkn6CvB83my8B0s6KSJW1ipIG3yWrN7IgmXr2LS1i1Ftrezs5WaO3srNrP5KuY33a8Dv8vZfSmVmZVmyeiPzF69l49YuAti4tavXui2e7sysYZWSQJT/wF5E7KK0wXezghYsW0fX9p0l1T3npLFVjsbMylVKAtkg6QJJw9Pr48CGagdmg9emEq44WiQ+MGWcB9DNGlgpVxIfBa4FLiebzuQuYG41g7LBbVRba8Fuq9Ftrfz80nfUISIzK0cp64FsjohZEXF4RBwREX8eEZtrEZwNTvOmTqJ1eEu3stbhLcybOqlOEZlZOYo9B3JxRHxR0j9QYNr0iLigqpHZoDXz+GyJl/y7sOZNnbS73MyaQ7EurNwys55Myipu5vGjnTDMmlyx50BuS++LeqtjZmZDV7EurNso0HWVExFnVSUiMzNrCsW6sL6U3t8L/B57lrE9B3iqijGZmVkTKNaF9W8Akj4TEW/LO3SbpHurHpmZmTW0Uh4kHCHpyNyOpAnAgBcblzRW0t2SHpX0cHpAEUmHSlou6fH0fkgql6RrJa2X9KCkEwYag5mZla+UBPIJ4B5J90i6h2zN8gsr0PYO4K8j4g+BKcB5kiaTLVV7V0RMJHtoMbd07XRgYnrNxfNxmZnVVZ9PokfEnZImAm9IRY9FxCsDbTgingWeTdu/lfQoMBqYAZyaqi0C7iFbvGoG8J00L9cKSW2SRqbvMTOzGuvzCkTSAcA84PyIeAAYJ+nMSgYhaTxwPLASOCKXFNL74anaaCB/cYjOVGZmZnVQShfWt4BXgZPTfidwZaUCkHQQ8EPgwoh4sVjVAmV73WYsaa6kDkkdW7ZsqVSYZmbWQykJ5KiI+CKwHSAiuij8y7zfJA0nSx43RMTiVPycpJHp+EggN+9WJ5A/t/cYYFPP74yIhRHRHhHtI0YMeKzfzMx6UUoCeVVSK+mvfUlHAQMeA5Ek4Drg0Yj4ct6hpcDstD2bbEndXPkH091YU4BtHv8wM6ufUqZzvwK4Exgr6QbgFOBDFWj7FOD/AmslrUllfwt8AbhZ0hzgaeDsdOwO4AxgPfAycG4FYjAzszIVTSDpKuExsqfRp5B1XX08In490IYj4j/ovSvstAL1AzhvoO3aHj3XJfeMuGbWH0UTSESEpCUR8WbgX2sUk9VAbl3y3NKyG7d2MX/xWgAnETMrSSljICsk/VHVI7GaKrQuedf2nSxYtq5OEZlZsyllDOTtwEclPQW8RNbtFBHxpmoGZtXV27rkxdYrNzPLV0oCmV71KKzmeluXfFRbax2iMbNm1GsXlqT9JV1I9hT6NGBjRPwy96pZhFYVXpfczAaq2BXIIrKHB/+d7CpkMvDxWgRl1ed1yc1soIolkMkRcQyApOuAX9QmJKsVr0tuZgNRLIFsz21ExI7skRBrZH6uw8xqqVgCOVZSbnJDAa1pP3cX1muqHp2VZMnqjXzqtod54eXdOd/PdZhZ1RVb0ralt2PWON715Xt4fPNLBY/lnutwAjGzaijlNl5rQJcvWcv3VjzdZz0/12Fm1eIE0mSWrN7IRTetYVeJ9f1ch5lVixNIEyn1qiPHz3WYWTU5gTSBy5es5YaVTxN7rb/Yu7bW4XzyrKM9/mFmVeME0sCWrN7IZT9ay0uv7uy7cp6Jhx/I8otOrU5QZmaJE0iD6m93Vc4pRx3KDR85ue+KZmYD5ATSgN7/jf/i508836/P+KrDzGqt6RKIpGnAV4AW4JsR8YU6h1QxxZ7p6I2A908Zx5Uzj6lOUGZmvWiqBCKpBfgq8C6gE7hP0tKIeKS+kQ3MktUbufCmNX1XzHPA8H343Hvf5EFyM6ubpkogwInA+ojYACDpB8AMoGkTSDlXHR/wFYeZNYBmSyCjgWfy9juBk/IrSJoLzAUYN25c7SLrp3KuOvYRfPlPj/NVh5k1hGZLIIWmBO72dERELAQWArS3t/fjyYnaedMVd/LiK/27NXffFvHF9x3r5GFmDaPZEkgnMDZvfwywqU6x9Fs5d1eBb801s8bUbAnkPmCipAnARmAW8Of1Dak0J312Oc/99tV+fUbA1X/mLisza0xNlUDSwlbnA8vIbuO9PiIernNYRZV71eGBcjNrdE2VQAAi4g7gjnrHUYpyxjr2bxGPffaMKkVkZlY5TZdAmkG5Vx3DhJOHmTUNJ5AKKjdxgKciMbPm4wRSIeV0VwG8Zr8WHvzUtCpEZGZWXU4gFfCGy+7gf3b2/5ETX3WYWTNzAhmAcqdcd+Iws8HACaRM5cxhBXCNn+sws0HCCaQMly9Z2+/k4asOMxtsnEDKcOPKZ/qulAzbR3zpbM9hZWaDjxNICZas3siCZevYtLWLUW2t7IzSBsw9h5WZDWZOIH1Ysnoj8xevpWt7dovuxq1dfX7Gt+aa2VCwT70DaHQLlq3bnTxK8YEp45w8zGxI8BVIAZcvWcuNK5/ps6uqRWJnBC0S55w01pMfmtmQ4gTSQ6nPdoxua+Xnl76jBhGZmTUmd2H1UModVq3DW5g3dVINojEza1y+AumhWLeVgFFtrcybOsm35ZrZkOcE0kNuXKNQ+ROf91TrZmY5denCkrRA0mOSHpT0I0ltecfmS1ovaZ2kqXnl01LZekmXViu2c04a269yM7Ohql5jIMuBN0bEm4D/BuYDSJpMts750cA04J8ktUhqAb4KTAcmA+ekuhV35cxj+MCUcbRIQHbl4eVlzcz2VpcurIj4Sd7uCuB9aXsG8IOIeAV4UtJ64MR0bH1EbACQ9INU95FqxHflzGOcMMzM+tAId2F9GPhx2h4N5N8G1ZnKeivfi6S5kjokdWzZsqUK4ZqZGVTxCkTST4HfK3Dosoi4NdW5DNgB3JD7WIH6QeFEV/B2qYhYCCwEaG9v7/8qT2ZmVpKqJZCIeGex45JmA2cCp0Xsvu2pE8gfrR4DbErbvZWbmVkd1OsurGnAJcBZEfFy3qGlwCxJ+0maAEwEfgHcB0yUNEHSvmQD7UtrHbeZme1Rr+dA/hHYD1iu7G6nFRHx0Yh4WNLNZIPjO4DzImIngKTzgWVAC3B9RDxcn9DNzAxAUeLaFs1I0hbgl0WqHAb8ukbh9EejxgWNG1ujxgWNG1ujxgWOrRyVjOv3I2JEX5UGdQLpi6SOiGivdxw9NWpc0LixNWpc0LixNWpc4NjKUY+4GuE2XjMza0JOIGZmVpahnkAW1juAXjRqXNC4sTVqXNC4sTVqXODYylHzuIb0GIiZmZVvqF+BmJlZmYZEAmnk6eMLxFqXdlPbYyXdLelRSQ9L+ngqP1TSckmPp/dDUrkkXZtifVDSCVWOr0XSakm3p/0JklamuG5KD5mSHkS9KcW1UtL4KsfVJumW9P/Yo5JObqBz9on03/IhSTdK2r9e503S9ZI2S3oor6zf50nS7FT/8TSjRTXiaojfGYViyzv2N5JC0mFpv2bnbLeIGPQv4HRgWNq+CrgqbU8GHiB7qHEC8ATZg4otaftIYN9UZ3IN4qxLu3ntjwROSNsHk021Pxn4InBpKr807/ydQTYRpoApwMoqx3cR8H3g9rR/MzArbX8d+Mu0/THg62l7FnBTleNaBPxF2t4XaGuEc0Y24eiTQGve+fpQvc4b8DbgBOChvLJ+nSfgUGBDej8kbR9Shbga4ndGodhS+ViyB6t/CRxW63O2O45q/c/bqC/gPcANaXs+MD/v2DLg5PRallferV4VY6tLu0XiuRV4F7AOGJnKRgLr0vY/A+fk1d9drwqxjAHuAt4B3J5+SH6d90O++9zl/jum7WGpnqoU12vIfkmrR3kjnLPcLNaHpvNwOzC1nucNGE/3X9T9Ok/AOcA/55V3q1epuHocq+vvjEKxAbcAxwJPsSeB1PScRcTQ6MLqoaLTx1dYvdrdS+q+OB5YCRwREc8CpPfDU7VaxnsNcDGwK+2/DtgaETsKtL07rnR8W6pfDUcCW4Bvpe61b0o6kAY4ZxGxEfgS8DTwLNl5WEVjnLec/p6nevyMNNTvDElnARsj4oEeh2oe26BJIJJ+mvp5e75m5NUpdfr43sqrrV7tdg9COgj4IXBhRLxYrGqBsorHK+lMYHNErCqx7Vqex2FkXQxfi4jjgZfIumJ6U7PY0njCDLKullHAgWSrevbWfkP8/5c0xM9mo/3OkHQAcBnwd4UO9xJD1WKr12SKFReDY/r4YvHUhKThZMnjhohYnIqfkzQyIp6VNBLYnMprFe8pwFmSzgD2J+s2ugZokzQs/bWc33Yurk5Jw4DXAs9XIa5cW50RsTLt30KWQOp9zgDeCTwZEVsAJC0G3kJjnLec/p6nTuDUHuX3VCOwBv2dcRTZHwQPKJuIdgxwv6QTi8RWvXNWyf7NRn2Rra/+CDCiR/nRdB8Q20A2GDYsbU9gz4DY0TWIsy7t5rUv4DvANT3KF9B9oPOLaftP6D5o94saxHgqewbR/4Xug8EfS9vn0X0w+OYqx/TvwKS0/cl0vup+zoCTgIeBA1J7i4C/qud5Y+8xkH6dJ7LxnCfJBoMPSduHViGuhvmd0TO2HseeYs8YSE3PWcQQGUQH1pP1Aa5Jr6/nHbuM7O6JdcD0vPIzyO5CeoJsFcVaxVqXdlPbbyW7tH0w71ydQdYPfhfweHo/NNUX8NUU61qgvQYxnsqeBHIk2Xox69Mvxf1S+f5pf306fmSVYzoO6EjnbUn6IW2IcwZ8CngMeAj4bvrFV5fzBtxINhazneyv4jnlnCeyMYn16XVuleJqiN8ZhWLrcfwp9iSQmp2z3MtPopuZWVkGzSC6mZnVlhOImZmVxQnEzMzK4gRiZmZlcQIxM7OyOIFYw5L0Oklr0utXkjbm7e9bwXbeKWlb3nevkfT2Sn1/L222SPr3Cn3XP0qDQ68zAAAEUUlEQVR6S9ruzJ85NpUNk7S1R9lySb83wHZfL2lN2j5O0jcH8n3WfAbNk+g2+ETEb8iesUDSJ4HfRcSX8usoexxXEbFr72/ol7sjYuYAv6ObvKe99xIRO4E/rkAbI4DjI+L8fnzmQODgiPjVQNvPiYg1ko6SNDqyObhsCPAViDWd9JfvQ5K+DtwPjM3/C1vSrNxfw5KOkLRYUoekX0iaUkY71ylbU+PHkvZPxyZKWiZplaR7Jf1BKv+epL+XdDfwOUmHS7pL0v2S/ildRbX1vCqQdGmK70FJf5fKDk5tPpDieF+BMM9mz0R/+bEfIOknks4t8Jl3AD9L9TolfVbSCkn3STohfe4JSR9JdfaR9OUUw9pe4oBstt8/K/X8WvNzArFmNRm4LrIJDIv9xXst2fQY7cCfAr11s7y9RxfW+FQ+iWxql6OBLiB3lbKQbAqQN5NN3f2Ped91FNn8SRcDnwbujIgTgDvIJjXsJs3xNY5s6pHjgLekLqkzgKci4tiIeCOwvEDcp5DNsJvvYLJf5osi4lsFPjMduDNv/6mImAKsAK4jm778LcBn0vGzyc73sWTT+18t6XD21kEFrqqsebgLy5rVExFxXwn13glMShPPARwiqTUiunrU26sLS9LrgfURsTYVrQLGpzGGKcAP8743/2fpX/K61N4KfBYgIm6X9NsCMZ5O9kt9ddo/CPgDsqn0vyDpC8BtEfHzAp8dSTadfL7bgc9FxE0F6pNivyBvf2l6X0u2TshLwEuSdimbmfmtwPdTt9uvJP0H0E42bUe+zRRIkDZ4OYFYs3opb3sX3aes3j9vW8CJEfFqme28kre9k+xnRsCvI+K4EmIrNJV2TwKujIjr9jogtZNdiSyQdHtEfK5HlS66/3sBfg5Ml3Rz9JirSNIkshl688dmcv/GXXT/9+5iz7+3FPuneGyIcBeWNb301/4LaVxiH7IumJyfks0yC2R3C1WgvReAZyW9J33nPpKO7aX6f5B1neW6qg4uUGcZMCcNbiNpjKTDJI0mu3Hgu8CXydYd6elR4PU9yv6WLIldW6D+NAqMmfThXmBWunPsCLJus44C9f6AbNJGGyKcQGywuISsX/8usllLc84DTkmD048AH+nl8z3HQN7TS72cWcBHJT1ANmX6mb3UuwL4E0n3kw1eP0f3KxQi4g6ydURWSFpLtmb5QWRjDvelW2UvBnpefQD8K93Xesg5H3itpM+RXUXkriym0X38oxS3kM3o+wBZQr4oIjYXqPf2FI8NEZ6N16yK0l1bOyJih6S3kg3It1fw+0V2lTM9elk9UtKbgX8gS2D3RsSJlWo/r41W4G7glDRWYkOAE4hZFUl6A9maDi1kVwEfje5L81aijZOB30bEXt1Hks4juwq7ICJ+Wsl2e7QziWx983ur1YY1HicQMzMri8dAzMysLE4gZmZWFicQMzMrixOImZmVxQnEzMzK4gRiZmZl+V8IEgkW813TcQAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "score = estimator.score(idx_train)\n", "print(\"The RMSE is %s kJ/mol\" % (str(score)) )\n", @@ -617,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -646,27 +669,9 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The RMSE is 0.30888228285160746 kJ/mol\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEKCAYAAAAMzhLIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XuUXXV99/H3J8MgE0AnaKBkIE8gxriklESnEI3tEpFrBYKPFyhUij4iT6FIsREitMQKQomote1jGwqWakS5hDFQIAIVaalBEiZhCJiWm5AJQrgEkExhMvk+f+x9wpnJOXv2ZM515vNaa9ac/Tv7nP3NWTP5zv5dvj9FBGZmZuVMqHcAZmbW2JwozMwskxOFmZllcqIwM7NMThRmZpbJicLMzDI5UZiZWSYnCjMzy+REYWZmmXaqdwCV8I53vCOmTZtW7zDMzJrKqlWrno+IycOdNyYSxbRp01i5cmW9wzAzayqSfpXnPHc9mZlZJicKMzPL5ERhZmaZnCjMzCyTE4WZmWUaE7OezMzGm67uXhYtX8eGTX1MaW9j/pEzmTe7oyrXcqIwM2syXd29LFjaQ1//AAC9m/pYsLQHoCrJwl1PZmZNZtHydduSREFf/wCLlq+ryvWcKMzMmsyGTX0jah8tJwozsyYzpb1tRO2j5URhZtZk5h85k7bWlkFtba0tzD9yZlWu58FsM7MmUxiw9qwnMzMra97sjqolhqHc9WRmZpmcKMzMLJMThZmZZXKiMDOzTE4UZmaWyYnCzMwyOVGYmVkmJwozM8tUk0Qh6WpJz0l6qKhtD0l3SPrv9PuktF2Svi3pUUkPSnpvLWI0M7PSanVH8c/AUUPazgfuiogZwF3pMcDRwIz063TgOzWK0czMSqhJooiIe4AXhzQfD1yTPr4GmFfU/i+RWAG0S9q7FnGamdn26jlGsVdEPAOQft8zbe8Ani46b33aNoik0yWtlLRy48aNVQ/WzGy8asTBbJVoi+0aIhZHRGdEdE6ePLkGYZmZjU/1TBTPFrqU0u/Ppe3rgX2LztsH2FDj2MzMLFXPRLEMODV9fCrw46L2T6ezn+YALxe6qMzMrPZqsh+FpGuBDwHvkLQeuAi4DLhO0meBp4BPpKffChwDPApsBk6rRYxmZlZaTRJFRJxU5qnDSpwbwJnVjcjMzPJqxMFsMzNrIE4UZmaWyYnCzMwyOVGYmVkmJwozM8vkRGFmZpmcKMzMLJMThZmZZXKiMDOzTE4UZmaWyYnCzMwyDVvrSVIn8HvAFKAPeAi4MyKG7lhnZtZQurp7WbR8HRs29TGlvY35R85k3uzt9kGzYZRNFJL+GDgbeAJYBawDdgE+CJwn6SHgLyLiqRrEaWY2Ihd29bBkxVPbdj3r3dTHgqU9AE4WI5R1R7ErMDci+ko9KWkWMIOkRLiZWcPo6u4dlCQK+voHWLR8nRPFCJVNFBHx91kvjIjVlQ/HzGz0Fi1ft/3+yakNm0r+7WsZsrqevp31wog4u/LhmJmNTKlxiKxkMKW9rYbRjQ1ZXU+rahaFmdkO6OruZcHSHvr6B4A3xyHaJ7by0ub+7c4XMP/ImTWOsvlldT1dU3wsafekOX5T9ajMzIbR1d3LF69bw0AM7mTq6x/gLTtNoK21ZVsCgSRJnDxnqscndsCw6ygk/bakbpJpsQ9LWiXpgNFeWNJMSauLvl6RdI6khZJ6i9qPGe21zGxsKdxJDE0SBS/39XPpxw6ko70NAR3tbXzzU7O4eN6BtQ10jMizZ/Zi4NyI+CmApA8BVwIfGM2FI2IdMCt9zxagF7gJOA34ZkR8fTTvb2ZjT2E8oneYAekp7W3Mm93hu4cKyZModi0kCYCIuFvSrhWO4zDgsYj4laQKv7WZjQVDxyPKaWtt8ThEheVJFI9L+gvge+nxKSSL8CrpRODaouOzJH0aWAl8MSJeqvD1zKwJdHX3snDZWjb1bT8wXUqLxKUfO9B3EhWWp9bTZ4DJwFKSrqHJJN1DFSFpZ+A44Pq06TvAdJJuqWeAK8q87nRJKyWt3LhxY6XCMbMG0dXdy/zr1+ROEm2tLVzxyYOcJKpg2DuK9K/5aq6ZOBp4ICKeTa/3bOEJSVcCt5SJazHJ+AmdnZ3l1taYWZNatHwd/Vvz/Wp3uI5TVeUtCvhlYFrx+RHxOxWK4SSKup0k7R0Rz6SHJ5DMtjKzceDCrh5+cN9T5MwPtLW2uKupBvKMUSwB5gM9wNZKXlzSROBw4PNFzZendaQCeHLIc2Y2BnV19/KlG9bwxkD+zgHfRdROnkSxMSKWVePiEbEZePuQtj+qxrXMrDF1dfcy/4Y19OdMEq0tYtHHPRZRS3kSxUWS/gm4C3i90BgRS6sWlZmNaclU1wfp6x9ZJ8Wkia1cdOwBThI1lidRnAa8G2jlza6nIJkFZWY2Ihd29fD9FSPbnaCjvY17z/9wlSKy4eRJFAdFhNe9m9modXX3jjhJTMCF/OotzzqKFZLeU/VIzGzMW7R83YjOb50A3/jULHc11VmeO4oPAqdKeoJkjEIkVWQrNT3WzMaJvJsGtbe1svA4j0U0ijyJ4qiqR2FmY9LQTYXK7RNRcMqcqa7w2oCydrhbCdwL3AbcHRH/U7OozKzpldpUqHWCaJkgBkqsqJs7fQ8niQaVdUcxh6Tb6SjgK5JeAJYDt0XEf9UiODNrLsV3EBOk7faL6N8atLe1IrHtzsLdTI0va4e7LcDd6ReS9iapy3SxpBnAzyPiT2oQo5k1gQu7eliy4ikKqSFrU6EnLvuD2gVmo5ZnjAKAtP7S1cDVkiYA769aVGbWVE6+8ufc+9iLuc6d0t5W5Wis0rLGKG4Gyq2pfx14TNJTEfF0VSIzs4bW1d3LeTc+yOtb8q+u9qZCzSnrjiJrK9KdgAOA6/Cdhdm409Xdy7nXrc5V5bVFYmsEU1zEr2lljVH8DEDS+yJiVfFzko6NiG9L8loKs3FmJN1MAm8mNAbkGaO4UtKpEdEDIOkk4Bzg5oj4P1WNzswaxkgSxLbXzJnqJDEG5EkUHwdukHQyyXTZTwNHVDUqM2soO5IkvC5i7MizFerjkk4EuoCngSMiIt86fDNrahd29XDtfU+XnepajldYjy1Zs556GDzraQ+gBbhPUiW3QjWzBvTuC27lf0aw4xwkdxFLPuf5LWNN1h3FR2sWhZk1jK7uXv7sR6vLzo0vxQlibMtKFC9ExG+yXixpt+HOMbPmUajPNJIkMWPPXZ0kxrisRPFjSauBHwOrIuI1AEn7A4cCnwSuBG4YTQCSngReBQaALRHRKWkP4EfANOBJ4JMR8dJormNmpXV197Jw2Vo29ZWv6lrKzi3icu9dPS5kraM4TNIxwOeBuZImAVuAdcC/AqdGxK8rFMehEfF80fH5wF0RcZmk89Pj8yp0LTNLdXX3Mv/6NfTnWTmX8kD1+JM56ykibgVurVEsxY4HPpQ+voakMKEThVkF7cje1bu0yEliHMpdFLCKAviJpAD+MSIWA3ulRQiJiGck7VnXCM3GkB1ZEwGw1+47c98Fh1chImt0jZAo5kbEhjQZ3CHpl3leJOl04HSAqVOnVjM+szHj8G/czX8/91ru8ztcn8logEQRERvS789Jugk4GHhW0t7p3cTewHMlXrcYWAzQ2dk5ssneZuNMV3cvX7phDW/kXBfR2iIWeaDaUhOGO0HSdElvSR9/SNLZktorcXFJu0ravfCYpDTIQ8Ay4NT0tFNJZl6Z2Q64sKuHc360OneSmDSx1UnCBslzR3Ej0CnpncBVJP+J/wA4pgLX3wu4SVIhlh9ExO2S7geuk/RZ4CngExW4ltm4M9IBa89oslLyJIqtEbFF0gnAtyLibyV1V+LiEfE4cFCJ9heAwypxDbPxqqu7lyVOElYBeRJFf1pa/FTg2LSttXohmdmO2tEifh3tbU4SVlaeRHEacAZwSUQ8IWk/4PvVDcvMRqKru5dzf7Sa/JuSvsnbk9pw8pQZf1jSecDU9PgJ4LJqB2Zmw+vq7uUrN6/lpc0jK79R4OmvlsewiULSsST7Z+8M7CdpFvBXEXFctYMzs/Iu7OphyYqnRlTATyS7zrmbyUYiT9fTQpK1DXcDRMTqtPvJzOqkMFCdN0kImOK7B9tBeRLFloh4OZ3CWuAFbmZ10NXdy6Ll6+jdlH+TSc9mstHKkygekvSHQIukGcDZwH9WNywzG2pHuppm7Lmrk4SN2rArs4E/BQ4AXgeuBV4BzqlmUGY22Ei7miYouZO449wPVTMsGyfyzHraDFyQfplZjRS6mTZs6mOCNGyS8EC1VUvZRCHpWxFxjqSbKTEm4VlPZtUztBT4cAvoPM3VqinrjuJ76fev1yIQM0tc2NWTe78IAd/81CwnCKuqrK1QV6Xff1a7cMzs2vueznVeoavJScKqLc+Cux6273p6GVgJXJwW8DOzCsnqZmqR2BrhNRFWU3mmx94GDJCUFgc4keSPmZeBf+bNQoFmVgEtUtlkccUnvU+E1V6eRDE3IuYWHfdIujci5ko6pVqBmY1lxTOaht4dnHTIviX3kJg7fQ8nCauLPOsodpN0SOFA0sHAbunhlqpEZTaGdXX3smBpD72b+gigd1MfC5b20NXdC8DF8w7klDlTaUmrIbRInDJnKks+9/46Rm3jmWKYaXeSOoHv8mZyeBX4LPAw8AcRcV1VI8yhs7MzVq5cWe8wzHKZe9m/lSzB0dHexr3nf7gOEdl4JWlVRHQOd15m15OkCcD+EXGgpLeRJJZNRafUPUmYNZsNZeo0lWs3q7fMrqeI2AqclT5+eUiSMLMdMKW9bUTtZvWWZ4ziDkl/LmlfSXsUvkZ74fT9firpEUlrJX0hbV8oqVfS6vTrmNFey6yRzD9yJm2tLYPavMucNbI8s54+k34/s6gtgP1Hee0twBcj4gFJuwOrJN2RPvfNiPCKcBuTCjOXys16Mms0eYoCVmWTooh4BngmffyqpEcA/6ZY08qa8jrUvNkdTgzWNIbtepI0UdKFkhanxzMkfbSSQUiaBswG7kubzpL0oKSrJU2q5LXMqmG4Ka9mzSzPGMV3gTeAD6TH64GLKxWApN2AG4FzIuIV4DvAdGAWyR3HFWVed7qklZJWbty4sVLhmO2QRcvX0dc/MKitr3+ARcvX1Skis8rJkyimR8TlQD9ARPSRlPAYNUmtJEliSUQsTd//2YgYSGdcXUmyX/d2ImJxRHRGROfkyZMrEY7ZDvOUVxvL8gxmvyGpjbQwoKTpJLvdjYqSTbivAh6JiG8Ute+djl8AnAA8NNprmVVSV3cvC5etZVNfPwCTJrbytrbWbcfFPOXVxoI8ieIi4HZgX0lLgLnAH1fg2nOBPyKpHbU6bfsycJKkWSSJ6Ung8xW4lllFdHX3Mv/6NfRvfbOiwUub+2mZIFonaFC7p7zaWJFn1tMdkh4A5pB0OX0hIp4f7YUj4j8o3YV162jf26xaFi1fNygZFAxsDd46sZWJO+/kKa825uS5owDYBXgpPf89koiIe6oXllljyhpz2LS5n+6/PKKG0ZjVRp6Ni/4a+BSwFtiaNgfgRGHjzpT2tpIF/QrPmY1Fee4o5gEzI2LUA9hmzW7+kTO3G6MAaG2RxyNszMozPfZxoLXagZg1g3mzO1j0iYNob3vzV2LSxFYWfdw7z9nYleeOYjOwWtJdFE2LjYizqxaVWQNz+Q0bb/IkimXpl5mZjUNlE4Wkt0bEKxFxTYnnplY3LDMzaxRZYxR3Fx6k3U7FuqoSjZmZNZysRFG8GG7oRkUVqfVkZmaNLytRRJnHpY7NzGyMyhrM3lPSuSR3D4XHpMcu12pmNk5kJYorgd1LPAb4p6pFZGZmDaVsooiIr9QyEDMza0x5Vmabmdk45kRhZmaZnCjMzCxT1srsc8s9B1C8falZPXR197Jo+TpvFGRWZVmzngqznGYCv8ub9Z6OxXtRWJ0N3ZK0d1Mf869fA+BkYVZhw856kvQT4L0R8Wp6vBC4vibRmZWxcNna7faE6N8aLFy21onCrMLyjFFMBd4oOn4DmFaVaIpIOkrSOkmPSjq/2tez5rKpr39E7Wa24/KUGf8e8AtJN5GU7jgB+JdqBiWpBfh74HBgPXC/pGUR8XA1r2tmZtsb9o4iIi4BTgNeAjYBp0XE16oc18HAoxHxeES8AfwQOL7K17QmMmli6U0Xy7Wb2Y7LOz12IvBKRPwNsF7SflWMCaADeLroeH3aZgbARcceQGvL4CLGrS3iomMPqFNEZmPXsF1Pki4COklmP32XZP/s7wNzqxhXqTLmg0YuJZ0OnA4wdar3URpvCgPWnh5rVn15xihOAGYDDwBExAZJu2e/ZNTWA/sWHe8DbCg+ISIWA4sBOjs7XfZ8jBjJ2gjvXW1WG3kSxRsREZICQNKuVY4J4H5gRtrF1QucCPxhDa5rddTV3cuCpT309Q8AydqIBUt7AK+NMKunPGMU10n6R6Bd0ueAO6lymfGI2AKcBSwHHgGui4i11bym1d+i5eu2JYmCvv4BFi1fV6eIzAxy3FFExNclHQ68QjJO8ZcRcUe1A4uIW4Fbq30daxwbNvWNqN3MaiPPYPZfR8R5wB0l2swqZkp7G70lksKU9rY6RGNmBXm6ng4v0XZ0pQMxm3/kTNpaWwa1tbW2MP/ImXWKyMwgu3rs/wX+BJgu6cGip3YH/rPagdnYkmc2k6e8mjUmRZSeWSrpbcAk4FKguNbSqxHxYg1iy62zszNWrlxZ7zCsjAu7eliy4qlBC2HaWlu49GMHOgmY1ZGkVRHROdx5ZbueIuLliHgS+BvgxYj4VUT8CuiXdEjlQrWxrKu7d7skAZ7NZNZM8qyj+A7w3qLj10q0mQ1S6GoqNThd4NlMZs0hT6JQFPVPRcRWSXleZ+PU0IVz5Xg2k1lzyDPr6XFJZ0tqTb++ADxe7cCseZVaODeUwLOZzJpEnkRxBvABklIa64FDSIvxmZUyXJeSgJPnTPVAtlmTyLMy+zmSWktmuZRbOAfQ4SmvZk0nax3FlyLickl/C9tNWiEizq5qZNa05h85c7sxCk+HNWteWXcUj6TfvUDBBhlu8ZwXzpmNLWUX3DUTL7irnVIzmny3YNac8i64y+p6upkSXU4FEXHcDsZmTSyrFLgThdnYlNX19PX0+8eA3yLZ/hTgJODJKsZkdVbctdQ+sZUIeLmvP3OQ2ovnzMausokiIn4GIOmrEfH7RU/dLOmeqkdmdTG0a+mlzf3bnuvd1IcofZvpxXNmY1eedRSTJe1fOEi3J51cvZCsnoZbLBck6yCKuRS42diWpxTHnwF3Syqsxp4GfL5qEVld5elCCpL1EJ7RZDY+5Flwd7ukGcC706ZfRsTr1Q3L6iVrHKKgo72Ne8//cI0iMrN6G7brSdJEYD5wVkSsAaZK+uhoLippkaRfSnpQ0k2S2tP2aZL6JK1Ov/5hNNexkSu1y1wxdzOZjT95xii+C7wBvD89Xg9cPMrr3gH8dkT8DvBfwIKi5x6LiFnp1xmjvI6N0LzZHVz6sQPpaG9DwKSJrbS3tSKSOwmvlzAbf/KMUUyPiE9JOgkgIvokDR3PHJGI+EnR4Qrg46N5P6usebM7nAzMbJs8dxRvSGojnRUpaTpQyTGKzwC3FR3vJ6lb0s8k/V4Fr2NmZjsgzx3FRcDtwL6SlgBzgT8e7kWS7iRZqDfUBRHx4/ScC4AtwJL0uWeAqRHxgqT3AV2SDoiIV0q8/+mk5c6nTp2a459hw9VoMjMrJbPWU9rFtA+wGZhDMoV+RUQ8P+oLS6eS7HVxWERsLnPO3cCfR0RmISfXehqeazSZ2VB5az1ldj2lW6B2RcQLEfGvEXFLhZLEUcB5wHHFSULSZEkt6eP9gRl4N72KyKrRZGaWJc8YxQpJv1vh6/4dsDtwx5BpsL8PPChpDXADcEZEvFjha49L5RbSuUaTmQ0nzxjFocAZkp4EXiPpfop0ausOiYh3lmm/EbhxR9/Xyiu3kM41msxsOHkSxdFVj8Iq4sKuHq6972kGImiROOmQfbl43oFA+V3nvHjOzIaTtR/FLiSDze8EeoCrImJLrQKzkTn5yp9z72Nv9tINRPD9FU8BcPG8A73rnJntsLKzniT9COgH/p3kruJXEfGFGsaW23ie9dTV3cvCZWvZ1Ndf8vkWiccuPabGUZlZMxj1DnfAeyLiwPTNrgJ+UangrDJKTXkdamAMbHVrZvWVNetp25+o7nJqTMPtHQHJHYWZ2Whk3VEcJKmwIlpAW3pcmPX01qpHZ4MMXVk9XDlwgJMO2bcGkZnZWJa1FWr5WtNWc0O7mbK2JS2YO32PbbOezMx2VJ7psVZHhbuIUncPhW1JhyaLSRNbuejYAzyjycwqwomigeUZrPa2pGZWbU4UDSzPYLW3JTWzastT68nqZLg6TF5ZbWa14DuKBlBun4ismU0d7mYysxpxoqizoaU3ejf18cXr1wDl6zN5DwkzqyV3PdXRhV09g5JEwcDW4IKbepg3u4NLP3YgHe1tiOQuwknCzGrNdxR1dO19T5d97rU3kruIebM7nBjMrK58R1FHrsNkZs3AdxQ1UmrAukUqmyxcocnMGoXvKGqgsHCud1MfQTJgvWBpD3P2n1T2NSfPmVq7AM3MMjhR1ECphXN9/QM8+UIfp8yZut3dwylzprpGk5k1jLp0PUlaCHwO2Jg2fTkibk2fWwB8FhgAzo6I5fWIsZLKLZzbsKmPi+cd6KRgZg2tnmMU34yIrxc3SHoPcCJwADAFuFPSuyIiu45FAym1b3W5hXNT2tvqEKGZ2cg0WtfT8cAPI+L1iHgCeBQ4uM4x5XbylT/n+yue2jZAXdi3etrb22hrHVy13eU3zKxZ1DNRnCXpQUlXSyqM6nYAxYsL1qdt25F0uqSVklZu3Lix1Ck11dXdW3LxHMCKx1/ywjkza1pV63qSdCfwWyWeugD4DvBVkirZXwWuAD5D6VmhJeePRsRiYDFAZ2dn3RckLFq+ruxzAxFeOGdmTatqiSIiPpLnPElXArekh+uB4r079wE2VDi0qsiq9Op9q82smdWl60nS3kWHJwAPpY+XASdKeouk/YAZwC9qHd+OyBqY9r7VZtbM6jVGcbmkHkkPAocCfwYQEWuB64CHgduBM5tlxtP8I2duN2AN3rfazJpfXabHRsQfZTx3CXBJDcOpiML4Q6l9JczMmplrPVWQB6zNbCxyoiij3K5zZmbjjRPFEF3dvXzl5rW8tLl/W1uhiB/gZGFm406jrcyuq0KV1+IkUdDXP5C5VsLMbKxyoihSqsprsay1EmZmY5UTRZHhEoGL+JnZeDSuxyiGDli3T2wt2e0ELuJnZuPXuE0UhfGIQldT76Y+WieI1hbRPzC4dFR7WysLjzvAA9lmNi6N20RRajyif2vQ3tbKrm/ZydNizcxS4zZRlBuPeLmvn9UXHVHjaMzMGte4HcwuNzDtAWszs8HGbaIoVcTPA9ZmZtsbt11PLuJnZpbPuE0U4CJ+ZmZ5jNuuJzMzy8eJwszMMjlRmJlZJicKMzPL5ERhZmaZFBHDn9XgJG0EflXvOEp4B/B8vYPIybFWR7PE2ixxgmOtpP8VEZOHO2lMJIpGJWllRHTWO448HGt1NEuszRInONZ6cNeTmZllcqIwM7NMThTVtbjeAYyAY62OZom1WeIEx1pzHqMwM7NMvqMwM7NMThQVJmmhpF5Jq9OvY4qeWyDpUUnrJB1ZzzjTeBZJ+qWkByXdJKk9bZ8mqa/o3/AP9Y4VQNJR6Wf3qKTz6x1PMUn7SvqppEckrZX0hbS97M9DPUl6UlJPGtPKtG0PSXdI+u/0+6QGiHNm0We3WtIrks5plM9V0tWSnpP0UFFbyc9RiW+nP78PSnpvPWLeEe56qjBJC4HfRMTXh7S/B7gWOBiYAtwJvCsiBrZ7kxqRdATwbxGxRdJfA0TEeZKmAbdExG/XK7ahJLUA/wUcDqwH7gdOioiH6xpYStLewN4R8YCk3YFVwDzgk5T4eag3SU8CnRHxfFHb5cCLEXFZmognRcR59YpxqPRnoBc4BDiNBvhcJf0+8BvgXwq/L+U+xzSZ/SlwDMm/4W8i4pB6xT4SvqOoneOBH0bE6xHxBPAoSdKom4j4SURsSQ9XAPvUM55hHAw8GhGPR8QbwA9JPtOGEBHPRMQD6eNXgUeAZqthfzxwTfr4GpJE10gOAx6LiIZZXBsR9wAvDmku9zkeT5JQIiJWAO3pHxgNz4miOs5Kby2vLrp97wCeLjpnPY31H8lngNuKjveT1C3pZ5J+r15BFWn0z2+b9I5sNnBf2lTq56HeAviJpFWSTk/b9oqIZyBJfMCedYuutBNJ7soLGvFzhfKfY9P8DA/lRLEDJN0p6aESX8cD3wGmA7OAZ4ArCi8r8VZV7/cbJtbCORcAW4AladMzwNSImA2cC/xA0lurHesw6vL5jZSk3YAbgXMi4hXK/zzU29yIeC9wNHBm2oXSsCTtDBwHXJ82NernmqUpfoZLGdc73O2oiPhInvMkXQnckh6uB/YtenofYEOFQ9vOcLFKOhX4KHBYpANWEfE68Hr6eJWkx4B3ASurHG6Wunx+IyGplSRJLImIpQAR8WzR88U/D3UVERvS789Juomka+9ZSXtHxDNpl8hzdQ1ysKOBBwqfZ6N+rqlyn2PD/wyX4zuKChvS53gCUJgNsQw4UdJbJO0HzAB+Uev4ikk6CjgPOC4iNhe1T04HDpG0P0msj9cnym3uB2ZI2i/96/JEks+0IUgScBXwSER8o6i93M9D3UjaNR1wR9KuwBEkcS0DTk1POxX4cX0iLOkkirqdGvFzLVLuc1wGfDqd/TQHeLnQRdXoPOupwiR9j+R2OIAngc8XfhjSLp7PkHTznBMRt5V7n1qQ9CjwFuCFtGlFRJwh6X8Df0US5wBwUUTcXKcwt0lnjXwLaAGujohL6hzSNpI+CPw70ANsTZu/TPIfXMmfh3pJk/9N6eFOwA8i4hJJbweuA6YCTwGfiIihA7U1J2kiSd/+/hHxctpW9vesxrFdC3yIpErss8BFQBclPsf0j4m/A44CNgOnRUQ979Jzc6IwM7NM7noyM7NMThRmZpbJicJ9FqwqAAAEhUlEQVTMzDI5UZiZWSYnCjMzy+REYXUn6e1FVUB/PaQq6M4VvM5HJL2swdVID63U+5e5Zoukf6/Qe/2dpA+kj9crrfZb9PxOkjYNabtD0m+N8rrvlLQ6fTxL0j+N5v2s+XhlttVdRLxAMic+q/quSKZzb93+HUbkpxFR0WJ3knYqKq44SFodeNS1siRNBmZHxFkjeM2uwO4R8evRXr8gIlZLmi6pIyJ6K/W+1th8R2ENK/1L9iEl+2E8AOxb/BezpBMLf91K2kvSUkkrJf0iXfk60utcpWQvidsk7ZI+N0PS8rR43j2S3pW2f1/SFZJ+CnxN0p6S7pL0gKT/l94VtQ/9K1/S+Wl8D0r6y7Rt9/Saa9I4Pl4izE8wuGhj4f0mSvqJpNNKvObDwL+l562XdImkFZLul/Te9HWPSfpces4ESd9IY+gpEwck5TI+lffztebnRGGN7j3AVWmBwqy/YL8NXB4RnSR7QJTrHjl0SNfTtLR9JvCtiDgA6OPN0tCLgT+JiPcBC0hW1hZMJ6mR9SWSley3p4X2biXZc2SQdGX5VJK9CGYBH0i7ko4BnoyIg9I9De4oEfdckj0uiu1O8p/2NRHx3RKvORq4vej4yYiYQ1JS/iqS0hcfAL6aPv8Jks/7IJJ9P74pqVQF2ZVU4C7Jmoe7nqzRPRYR9+c47yPAzKSHCoBJktoiom/Iedt1PUl6J8leFz1p0ypgWjoGMAe4seh9i39nri/qCvsgcAlARNwi6dUSMR5B8p93d3q8G0mxxfuAyyRdBtwcEfeWeO3ewMYhbbcAX4uIH5U4nzT2s4uOC7WxeoCdIuI14DVJW5VUvf0gSTmPAeDXkv4D6CTZMKrYc5RIhDZ2OVFYo3ut6PFWBpdq3qXosYCD002NdsTrRY8HSH43BDwfEbNyxFaqhPRQAi6OiKu2e0LqJLmzWCTploj42pBT+hj87wW4Fzha0nWFyr9F7zcTeGLI2Enh37iVwf/erbz5781jlzQeGyfc9WRNI/3r/aV03GACSddJwZ3AmYUDSeX+cx/J9V4CnpF0QvqeEyQdVOb0/yDp8ip0Me1e4pzlwGfTQWYk7SPpHZI6SAbwvwd8Ayi1l/IjwDuHtH2ZJFl9u8T5R1FiTGMY95BUOG6RtBdJd1eponXvorGqtVqVOVFYszmPpN/9LpL6/gVnAnPTQeKHgc+Vef3QMYoTypxXcCJwhqQ1wFqSvTtKuQj4A0kPkAwiP8vgOw4i4lbgBmCFpB6SCqO7kYwJ3J9OQf0SMPRuAuBfSaqUDnUW8DZJXyO5KyjcKRzF4PGJPG4AfgmsIUm850ZEqT0pDk3jsXHC1WPNKiCdJbUlIrYoKTn+rXRgvVLvL5K7lqPTnfNKnfM+4G9JEtU9EVHxPdkltQE/Jdkhb6DS72+NyYnCrAIkvZtkY50Wkr/qz4iIobOURnuN9wOvRsR23T6SziS5qzo7Iu6s5HWHXGcmyZ7Q91TrGtZ4nCjMzCyTxyjMzCyTE4WZmWVyojAzs0xOFGZmlsmJwszMMjlRmJlZpv8PHve8+xg2OpEAAAAASUVORK5CYII=\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "idx_train = np.arange(0,100)\n", "\n", @@ -695,27 +700,9 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The RMSE is 0.06417642021013359 kJ/mol\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEKCAYAAAAMzhLIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X2UXXV97/H3J8MgE4gOaKQkkBuIMS4xJdEppEa7RIQAVQheUShUil6RWygiNkKEFqyglIhY215bKFgqiPIQRqBADFSkpQaZMCHDg2l5iJAJQkACCFOYTL73j71PODM5Z8+emfM483mtNWvO/p19zv7OWZN8Z/8evj9FBGZmZuVMqncAZmbW2JwozMwskxOFmZllcqIwM7NMThRmZpbJicLMzDI5UZiZWSYnCjMzy+REYWZmmXaodwCV8La3vS1mzpxZ7zDMzJrK6tWrn4uIqcOdNy4SxcyZM+nq6qp3GGZmTUXSr/Kc564nMzPL5ERhZmaZnCjMzCyTE4WZmWVyojAzs0zjYtaTmdlE09ndy7IV69i4uY9p7W0sWTSHxfOnV+VaThRmZk2ms7uXpct76OsfAKB3cx9Ll/cAVCVZuOvJzKzJLFuxbluSKOjrH2DZinVVuZ4ThZlZk9m4uW9E7WPlRGFm1mSmtbeNqH2snCjMzJrMkkVzaGttGdTW1trCkkVzqnI9D2abmTWZwoC1Zz2ZmVlZi+dPr1piGMpdT2ZmlsmJwszMMjlRmJlZJicKMzPL5ERhZmaZnCjMzCyTE4WZmWVyojAzs0w1SRSSrpD0rKQHi9p2k7RS0n+n33dN2yXpO5IelbRW0ntrEaOZmZVWqzuKfwYOHdJ2FnBnRMwG7kyPAQ4DZqdfJwHfrVGMZmZWQk0SRUTcDfxmSPORwJXp4yuBxUXt/xKJVUC7pD1qEaeZmW2vnmMUu0fE0wDp97en7dOBp4rO25C2DSLpJEldkro2bdpU9WDNzCaqRhzMVom22K4h4tKI6IiIjqlTp9YgLDOziameieKZQpdS+v3ZtH0DsFfReXsCG2scm5mZpeqZKG4CTkgfnwD8uKj90+nspwXAi4UuKjMzq72a7Ech6RrgQ8DbJG0AzgUuBK6V9FngSeDo9PRbgcOBR4FXgRNrEaOZmZVWk0QREceWeeqgEucGcEp1IzIzs7wacTDbzMwaiBOFmZllcqIwM7NMThRmZpbJicLMzDI5UZiZWSYnCjMzy+REYWZmmZwozMwskxOFmZllcqIwM7NMw9Z6ktQBfBCYBvQBDwJ3RMTQHevMzBpKZ3cvy1asY+PmPqa1t7Fk0RwWz99uHzQbRtlEIelPgNOAJ4DVwDpgJ+ADwJmSHgT+IiKerEGcZmYjck5nD1evenLbrme9m/tYurwHwMlihLLuKHYGFkZEX6knJc0DZpOUCDczaxid3b2DkkRBX/8Ay1asc6IYobKJIiL+PuuFEbGm8uGYmY3dshXrtt8/ObVxc8m/fS1DVtfTd7JeGBGnVT4cM7ORKTUOkZUMprW31TC68SGr62l1zaIwMxuFzu5eli7voa9/AHhjHKJ9cisvvNq/3fkCliyaU+Mom19W19OVxceSpiTN8duqR2VmNozO7l6+dO0DDMTgTqa+/gHetMMk2lpbtiUQSJLEcQtmeHxiFIZdRyHpPZK6SabFPixptaR9x3phSXMkrSn6eknS6ZLOk9Rb1H74WK9lZuNL4U5iaJIoeLGvn298fC7T29sQML29jUs+NY/zF8+tbaDjRJ49sy8FzoiInwJI+hBwGfD+sVw4ItYB89L3bAF6gRuBE4FLIuKbY3l/Mxt/CuMRvcMMSE9rb2Px/Om+e6iQPIli50KSAIiIuyTtXOE4DgIei4hfSarwW5vZeDB0PKKcttYWj0NUWJ5E8bikvwC+nx4fT7IIr5KOAa4pOj5V0qeBLuBLEfFCha9nZk0i710EQIvENz4+13cSFZan1tNngKnAcpKuoakk3UMVIWlH4AjgurTpu8Askm6pp4GLy7zuJEldkro2bdpUqXDMrIGc09nDF3+0JleSaGtt4eJP7uckUQWKMoNBNQtAOhI4JSIOKfHcTOCWiHhP1nt0dHREV1dXdQI0s5oayR1EwXTXcRoVSasjomO48/IWBfwKMLP4/Ij43bEEWORYirqdJO0REU+nh0eRzLYyswnguMt+zj2P5a832tba4q6mGsgzRnE1sAToAbZW8uKSJgMHA58var4orSMVwPohz5nZONTZ3cuXr3+A1wfy93D4LqJ28iSKTRFxUzUuHhGvAm8d0vbH1biWmTWmzu5ellz/AP05k4SASz41zwmihvIkinMl/RNwJ/BaoTEillctKjMb15Kprmvp6x9ZJ4VXV9dHnkRxIvAuoJU3up6CZBaUmdmInNPZw1WrRr47gbua6idPotgvIrzu3czGrLO7d8RJYuGs3bj6c79fpYgsjzzrKFZJenfVIzGzcW/ZinUjOt9JojHkuaP4AHCCpCdIxihEUkW2UtNjzWyCyLtpUHtbK+cdsa+7mRpEnkRxaNWjMLNxaeimQuX2iSg4fsEMV3htQFk73HUB9wC3AXdFxP/ULCoza3qlNhVqnSRaJomBrdtPhV04azcniQaVdUexgKTb6VDgq5KeB1YAt0XEf9UiODNrLsV3EJOk7faL6N8atLe1IrHtzsLdTI0va4e7LcBd6ReS9gAOA86XNBv4eUT8aQ1iNLMG19ndy1eWr+XVonURWZsKPXHhH9YqNKuAPGMUAKT1l64ArpA0CfBUBDOjs7uXM65dQ4nepJKmtbdVNyCruKwxiptJFtaV8hrwmKQnI+KpqkRmZg2ts7uXM29Yy2tb8q+u9qZCzSnrjiJrK9IdgH2Ba/GdhdmEM5K7iBaJrRFM88rqppU1RvEzAEnvi4jVxc9J+lhEfEeS11KYTSAjrfIq8GZC40CeMYrLJJ0QET0Ako4FTgdujoj/U9XozKxhHPytu/jvZ18Z0WtcwG98yJMoPgFcL+k4kumynwa2243OzMav4y77+YiThBfPjR/DJoqIeFzSMUAn8BRwSETk36PQzJrWOZ09XHPvU2WnupayY4u46BPubhpPsmY99TB41tNuQAtwr6RKboVqZg3oXWffyv+MYMc5cBG/8SrrjuKjNYvCzBpGZ3cvX/zRmrJz40txghjfshLF8xHx26wXS9pluHPMrHkU6jM5SVixrETxY0lrgB8DqyPiFQBJ+wAHAp8ELgOuH0sAktYDLwMDwJaI6JC0G/AjYCawHvhkRLwwluuYWWmd3b2cd9NDbO4rX9W1FCeIiSNrHcVBkg4HPg8slLQrsAVYB/wrcEJE/LpCcRwYEc8VHZ8F3BkRF0o6Kz0+s0LXMrNUZ3cvS657gP689TfwbKaJKHPWU0TcCtxao1iKHQl8KH18JUlhQicKswoazd7VO7XISWICyl0UsIoC+ImkAP4xIi4Fdk+LEBIRT0t6e10jNBtHRrNwDmD3KTty79kHVyEia3SNkCgWRsTGNBmslPTLPC+SdBJwEsCMGTOqGZ/ZuHHABSt55uXXc58/3fWZjAZIFBGxMf3+rKQbgf2BZyTtkd5N7AE8W+J1lwKXAnR0dIxssrfZBDPSGk2tLWKZF81ZatJwJ0iaJelN6eMPSTpNUnslLi5pZ0lTCo9JSoM8CNwEnJCedgLJzCszG4VzOns4/UdrcieJXSe3OknYIHnuKG4AOiS9A7ic5D/xHwCHV+D6uwM3SirE8oOIuF3SfcC1kj4LPAkcXYFrmU04Ix2w9owmKyVPotgaEVskHQV8OyL+VlJ3JS4eEY8D+5Vofx44qBLXMJuoOrt7udpJwiogT6LoT0uLnwB8LG1rrV5IZjZaoyni59lMNpw8ieJE4GTggoh4QtLewFXVDcvMRmo00169utryyFNm/GFJZwIz0uMngAurHZiZ5TfS/SJEsqmQu5osj2EThaSPkeyfvSOwt6R5wF9FxBHVDs7MyhtNjSYnCBuNPF1P55GsbbgLICLWpN1PZlYnI63RJGCaF8/ZKOVJFFsi4sV0CmuBF7iZ1UFndy/LVqyjd3P+TSY9m8nGKk+ieFDSHwEtkmYDpwH/Wd2wzGyoczp7uHrVkyP6K23223d2krAxG3ZlNvBnwL7Aa8A1wEvA6dUMyswGK6yJGOmGQivP+FC1QrIJJM+sp1eBs9MvM6uDZSvW5U4S7W2tnHfEvh6LsIopmygkfTsiTpd0MyXGJDzryax6RtPNtOvkVs79mBOEVV7WHcX30+/frEUgZpYYSX0mAZd8ap6Tg1VV1laoq9PvP6tdOGZ2zb1P5TqvsCbCScKqLc+Cux6273p6EegCzk8L+JlZhQxXp8lrIqzW8kyPvQ0YICktDnAMye/qi8A/80ahQDOrgBapbLKY3t7GPWd9uMYR2USXJ1EsjIiFRcc9ku6JiIWSjq9WYGbjWWHh3MbNfdvdHRx7wF4lxyhaJokli+bUOlSzXOsodpF0QOFA0v7ALunhlqpEZTaOdXb3snR5D72b+wigd3MfS5f30NndC8D5i+dy/IIZFNdC2HnHFi4+2rvOWX0ohusPlTqA7/FGcngZ+CzwMPCHEXFtVSPMoaOjI7q6uuodhlkuCy/8t5IlONytZLUmaXVEdAx3XmbXk6RJwD4RMVfSW0gSy+aiU+qeJMyazcYydZrKtZvVW2bXU0RsBU5NH784JEmY2ShMa28bUbtZveUZo1gp6c8l7SVpt8LXWC+cvt9PJT0i6SFJX0jbz5PUK2lN+nX4WK9l1kiWLJpDW2vLoLa21hYPVFvDyjPr6TPp91OK2gLYZ4zX3gJ8KSLulzQFWC1pZfrcJRHhFeE2LhUGpMvNejJrNHmKAlZlk6KIeBp4On38sqRHAP9LsaaVNeV1qMXzpzsxWNMYtutJ0mRJ50i6ND2eLemjlQxC0kxgPnBv2nSqpLWSrpC0ayWvZVYNw015NWtmecYovge8Drw/Pd4AnF+pACTtAtwAnB4RLwHfBWYB80juOC4u87qTJHVJ6tq0aVOlwjEblWUr1tHXPzCora9/gGUr1tUpIrPKyZMoZkXERUA/QET0waC1QKMmqZUkSVwdEcvT938mIgbSGVeXkezXvZ2IuDQiOiKiY+rUqZUIx2zUPOXVxrM8g9mvS2ojLQwoaRbJbndjomQT7suBRyLiW0Xte6TjFwBHAQ+O9VpmldTZ3ct5Nz3E5r5+INkH4i1trduOi3nKq40HeRLFucDtwF6SrgYWAn9SgWsvBP6YpHbUmrTtK8CxkuaRJKb1wOcrcC2ziujs7mXJdQ/Qv/WNigYvvNpPyyTROkmD2j3l1caLPLOeVkq6H1hA0uX0hYh4bqwXjoj/oHQX1q1jfW+zalm2Yt2gZFAwsDV48+RWJu+4g6e82riT544CYCfghfT8d0siIu6uXlhmjSlrzGHzq/10/+UhNYzGrDbybFz018CngIeArWlzAE4UNuFMa28rWdCv8JzZeJTnjmIxMCcixjyAbdbsliyas90YBUBri/eKsPErz/TYx4HWagdi1gwWz5/OsqP3o73tjX8Su05uZdknvFeEjV957iheBdZIupOiabERcVrVojJrYC6/YRNNnkRxU/plZmYTUNlEIenNEfFSRFxZ4rkZ1Q3LzMwaRdYYxV2FB2m3U7HOqkRjZmYNJytRFC+GG7pRUUVqPZmZWePLShRR5nGpYzMzG6eyBrPfLukMkruHwmPSY5drNTObILISxWXAlBKPAf6pahGZmVlDKZsoIuKrtQzEzMwaU56V2WZmNoE5UZiZWSYnCjMzy5S1MvuMcs8BFG9famZm41fWrKfCLKc5wO/xRr2nj+G9KKwBnNPZwzX3PsVABC0Sxx6wF+cvnlvvsMzGnWFnPUn6CfDeiHg5PT4PuK4m0ZmVcU5nD1etenLb8UDEtmMnC7PKyjNGMQN4vej4dWBmVaIpIulQSeskPSrprGpfz5rLNfc+NaJ2Mxu9PGXGvw/8QtKNJKU7jgL+pZpBSWoB/h44GNgA3Cfppoh4uJrXteYxEKWryJRrN7PRGzZRRMQFkm4DPpg2nRgR3dUNi/2BRyPicQBJPwSOBJwoDIAWqWRSaJHrVZpVWt7psZOBlyLib4ANkvauYkwA04HiPoQNaZsZAMcesNeI2s1s9IZNFJLOBc4ElqZNrcBV1QyK0mXMB/35KOkkSV2SujZt2lTlcKzRnL94LscvmLHtDqJF4vgFMzyQbVYFecYojgLmA/cDRMRGSVOyXzJmG4DiPw33BDYWnxARlwKXAnR0dLhjepzo7O5l2Yp1bNzcx7T2NpYsmlN2f+rzF891YjCrgTyJ4vWICEkBIGnnKscEcB8wO+3i6gWOAf6oBte1Ours7mXp8h76+gcA6N3cx9LlPQBlk4WZVV+eMYprJf0j0C7pc8AdVLnMeERsAU4FVgCPANdGxEPVvKbV37IV67YliYK+/gGWrVhXp4jMDPLNevqmpIOBl0hWaf9lRKysdmARcStwa7WvY41j4+a+EbWbWW0Mmygk/XVEnAmsLNFmVjHT2tvoLZEUprW31SEaMyvI0/V0cIm2wyodiNmSRXNoa20Z1NbW2sKSRXPqFJGZQXb12P8L/CkwS9LaoqemAP9Z7cBsfMkzm6lwnHfWk5nVhqJMyQNJbwF2Bb4BFNdaejkiflOD2HLr6OiIrq6ueodhZZzT2cPVq54ctBCmrbWFb3x8rpOAWR1JWh0RHcOdV7brKSJejIj1wN8Av4mIX0XEr4B+SQdULlQbzzq7e7dLEuDZTGbNJM86iu8C7y06fqVEm9kgha6mUoPTBZ7NZNYc8iQKRVH/VERslZTndTZBDV04V45nM5k1hzyznh6XdJqk1vTrC8Dj1Q7MmlephXNDCTybyaxJ5EkUJwPvJymlsQE4ADipmkFZcxuuS0nAcQtmeCDbrEnkWZn9LEmtJbNcyi2cA5juKa9mTSdrHcWXI+IiSX8L201aISJOq2pk1rSWLJqz3RiFp8OaNa+sO4pH0u9eoGCDDLd4zgvnzMaXsgvumokX3NVOqRlNvlswa055F9xldT3dTIkup4KIOGKUsVkTyyoF7kRhNj5ldT19M/3+ceB3eGP702OB9VWMyRrI0G6mcoPUXjxnNn6VTRQR8TMASV+LiD8oeupmSXdXPTKru1I7zonSt5lePGc2fuVZYT1V0j4R8ThAuj3p1OqGZfWUVX4jYLtk4VLgZuNbnkTxReAuSYXV2DOBz1ctIqurPOU3gmQ9hGc0mU0MeRbc3S5pNvCutOmXEfFadcOyeslTfmN6exv3nPXhGkVkZvU2bAkPSZOBJcCpEfEAMEPSR8dyUUnLJP1S0lpJN0pqT9tnSuqTtCb9+oexXMdGbrhBaXczmU08eWo9fQ94Hfj99HgDcP4Yr7sSeE9E/C7wX8DSoucei4h56dfJY7yOjVDWoPT09javlzCbgPIkilkRcRHQDxARfSTjmaMWET+JiC3p4Spgz7G8n1VOuX2rv/2pedxz1oedJMwmoDyJ4nVJbaQTXSTNAio5RvEZ4Lai470ldUv6maQPVvA6lsPi+dP5xsfnMr29DeG7CDPLN+vpXOB2YC9JVwMLgT8Z7kWS7iBZqDfU2RHx4/Scs4EtwNXpc08DMyLieUnvAzol7RsRL5V4/5NIy53PmDEjx49hw9VoKlg8f7oTg5ltk1nrSZJIuoVeBRaQdDmtiojnxnxh6QSSvS4OiohXy5xzF/DnEZFZyMm1nobnGk1mNlTeWk+ZXU/pFqidEfF8RPxrRNxSoSRxKHAmcERxkpA0VVJL+ngfYDbeTa8ismo0mZllyTNGsUrS71X4un8HTAFWDpkG+wfAWkkPANcDJ0fEbyp87Qmp3LRX12gys+HkGaM4EDhZ0nrgFdIKDunU1lGJiHeUab8BuGG072vllSvo5xpNZjacPInisKpHYRVxTmcP19z7FAMRtEgce8BenL94LlB+1zkvnjOz4WTtR7ETyWDzO4Ae4PKitQ/WYI677Ofc89gbvXQDEVy16kkAzl8817vOmdmolZ31JOlHJIvs/p3kruJXEfGFGsaW20Se9dTZ3ct5Nz3E5r7+ks+3SDz2jcNrHJWZNYMx73AHvDsi5qZvdjnwi0oFZ5WRp9LrwDjY6tbM6itr1tO2P1Hd5dSY8lR6bdGYqq2YmWXeUewnqbAiWkBbelyY9fTmqkdng+TdlrTYsQfsVYPIzGw8y9oKtaXcc1Z7I9mWtGDhrN22zXoyMxutPNNjrY5Gui0pwK6TWzn3Y/t6RpOZVYQTRQPztqRm1gicKBqYtyU1s0aQp9aT1Ym3JTWzRuA7igZQbp+IrJlN093NZGY14kRRZ0NLb/Ru7uNL1z0AlK/P5D0kzKyW3PVUR+d09gxKEgUDW4Ozb+zxtqRm1hB8R1FH19z7VNnnXnk9uYvwtqRmVm++o6gj12Eys2bgO4oaKTVg3SKVTRau0GRmjcJ3FDVQWDjXu7mPIBmwXrq8hwX77Fr2NcctmFG7AM3MMjhR1ECphXN9/QOsf76P4xfM2O7u4fgFM1yjycwaRl26niSdB3wO2JQ2fSUibk2fWwp8FhgATouIFfWIsZLKLZzbuLmP8xfPdVIws4ZWzzGKSyLim8UNkt4NHAPsC0wD7pD0zojIrmPRQErtW11u4dy09rY6RGhmNjKN1vV0JPDDiHgtIp4AHgX2r3NMuR132c+5atWT2waoC/tWz3xrG22tg6u2u/yGmTWLeiaKUyWtlXSFpMKo7nSgeHHBhrRtO5JOktQlqWvTpk2lTqmpzu7ekovnAFY9/oIXzplZ06pa15OkO4DfKfHU2cB3ga+RVMn+GnAx8BlKzwotOX80Ii4FLgXo6Oio+4KEZSvWlX1uIMIL58ysaVUtUUTER/KcJ+ky4Jb0cANQvHfnnsDGCodWFVmVXr1vtZk1s7p0PUnao+jwKODB9PFNwDGS3iRpb2A28ItaxzcaWQPT3rfazJpZvcYoLpLUI2ktcCDwRYCIeAi4FngYuB04pVlmPC1ZNGe7AWvwvtVm1vzqMj02Iv4447kLgAtqGE5FFMYfSu0rYWbWzFzrqYI8YG1m45ETRRnldp0zM5tonCiG6Ozu5as3P8QLr/ZvaysU8QOcLMxswmm0ldl1VajyWpwkCvr6BzLXSpiZjVdOFEVKVXktlrVWwsxsvHKiKDJcInARPzObiCb0GMXQAev2ya0lu53ARfzMbOKasImiMB5R6Grq3dxH6yTR2iL6BwaXjmpva+W8I/b1QLaZTUgTNlGUGo/o3xq0t7Wy85t28LRYM7PUhE0U5cYjXuzrZ825h9Q4GjOzxjVhB7PLDUx7wNrMbLAJmyhKFfHzgLWZ2fYmbNeTi/iZmeUzYRMFuIifmVkeE7bryczM8nGiMDOzTE4UZmaWyYnCzMwyOVGYmVkmRcTwZzU4SZuAX9U7jhLeBjxX7yBycqzV0SyxNkuc4Fgr6X9FxNThThoXiaJRSeqKiI56x5GHY62OZom1WeIEx1oP7noyM7NMThRmZpbJiaK6Lq13ACPgWKujWWJtljjBsdacxyjMzCyT7yjMzCyTE0WFSTpPUq+kNenX4UXPLZX0qKR1khbVM840nmWSfilpraQbJbWn7TMl9RX9DP9Q71gBJB2afnaPSjqr3vEUk7SXpJ9KekTSQ5K+kLaX/X2oJ0nrJfWkMXWlbbtJWinpv9PvuzZAnHOKPrs1kl6SdHqjfK6SrpD0rKQHi9pKfo5KfCf9/V0r6b31iHk03PVUYZLOA34bEd8c0v5u4Bpgf2AacAfwzogY2O5NakTSIcC/RcQWSX8NEBFnSpoJ3BIR76lXbENJagH+CzgY2ADcBxwbEQ/XNbCUpD2APSLifklTgNXAYuCTlPh9qDdJ64GOiHiuqO0i4DcRcWGaiHeNiDPrFeNQ6e9AL3AAcCIN8LlK+gPgt8C/FP69lPsc02T2Z8DhJD/D30TEAfWKfSR8R1E7RwI/jIjXIuIJ4FGSpFE3EfGTiNiSHq4C9qxnPMPYH3g0Ih6PiNeBH5J8pg0hIp6OiPvTxy8DjwDNVsP+SODK9PGVJImukRwEPBYRDbO4NiLuBn4zpLnc53gkSUKJiFgFtKd/YDQ8J4rqODW9tbyi6PZ9OvBU0TkbaKz/SD4D3FZ0vLekbkk/k/TBegVVpNE/v23SO7L5wL1pU6nfh3oL4CeSVks6KW3bPSKehiTxAW+vW3SlHUNyV17QiJ8rlP8cm+Z3eCgnilGQdIekB0t8HQl8F5gFzAOeBi4uvKzEW1W932+YWAvnnA1sAa5Om54GZkTEfOAM4AeS3lztWIdRl89vpCTtAtwAnB4RL1H+96HeFkbEe4HDgFPSLpSGJWlH4AjgurSpUT/XLE3xO1zKhN7hbrQi4iN5zpN0GXBLergB2Kvo6T2BjRUObTvDxSrpBOCjwEGRDlhFxGvAa+nj1ZIeA94JdFU53Cx1+fxGQlIrSZK4OiKWA0TEM0XPF/8+1FVEbEy/PyvpRpKuvWck7RERT6ddIs/WNcjBDgPuL3yejfq5psp9jg3/O1yO7ygqbEif41FAYTbETcAxkt4kaW9gNvCLWsdXTNKhwJnAERHxalH71HTgEEn7kMT6eH2i3OY+YLakvdO/Lo8h+UwbgiQBlwOPRMS3itrL/T7UjaSd0wF3JO0MHEIS103ACelpJwA/rk+EJR1LUbdTI36uRcp9jjcBn05nPy0AXix0UTU6z3qqMEnfJ7kdDmA98PnCL0PaxfMZkm6e0yPitnLvUwuSHgXeBDyfNq2KiJMl/W/gr0jiHADOjYib6xTmNumskW8DLcAVEXFBnUPaRtIHgH8HeoCtafNXSP6DK/n7UC9p8r8xPdwB+EFEXCDprcC1wAzgSeDoiBg6UFtzkiaT9O3vExEvpm1l/53VOLZrgA+RVIl9BjgX6KTE55j+MfF3wKHAq8CJEVHPu/TcnCjMzCyTu57MzCyTE4WZmWVyojAzs0xOFGZmlsmJwszMMjlRWN1JemtRFdBfD6kKumMFr/MRSS9qcDXSAyv1/mWu2SLp3yv0Xn8n6f3p4w1Kq/0WPb+DpM1D2lZK+p0xXvcdktakj+dJ+qexvJ+oUg23AAAEKklEQVQ1H6/MtrqLiOdJ5sRnVd8VyXTurdu/w4j8NCIqWuxO0g5FxRUHSasDj7lWlqSpwPyIOHUEr9kZmBIRvx7r9QsiYo2kWZKmR0Rvpd7XGpvvKKxhpX/JPqhkP4z7gb2K/2KWdEzhr1tJu0taLqlL0i/Sla8jvc7lSvaSuE3STulzsyWtSIvn3S3pnWn7VZIulvRT4OuS3i7pTkn3S/p/6V1R+9C/8iWdlca3VtJfpm1T0ms+kMbxiRJhHs3goo2F95ss6SeSTizxmg8D/5aet0HSBZJWSbpP0nvT1z0m6XPpOZMkfSuNoadMHJCUy/hU3s/Xmp8ThTW6dwOXpwUKs/6C/Q5wUUR0kOwBUa575MAhXU8z0/Y5wLcjYl+gjzdKQ18K/GlEvA9YSrKytmAWSY2sL5OsZL89LbR3K8meI4OkK8tnkOxFMA94f9qVdDiwPiL2S/c0WFki7oUke1wUm0Lyn/aVEfG9Eq85DLi96Hh9RCwgKSl/OUnpi/cDX0ufP5rk896PZN+PSySVqiDbRQXukqx5uOvJGt1jEXFfjvM+AsxJeqgA2FVSW0T0DTlvu64nSe8g2euiJ21aDcxMxwAWADcUvW/xv5nrirrCPgBcABARt0h6uUSMh5D8592dHu9CUmzxXuBCSRcCN0fEPSVeuwewaUjbLcDXI+JHJc4njf20ouNCbaweYIeIeAV4RdJWJVVvP0BSzmMA+LWk/wA6SDaMKvYsJRKhjV9OFNboXil6vJXBpZp3KnosYP90U6PReK3o8QDJvw0Bz0XEvByxlSohPZSA8yPi8u2ekDpI7iyWSbolIr4+5JQ+Bv+8APcAh0m6tlD5t+j95gBPDBk7KfyMWxn8827ljZ83j53SeGyCcNeTNY30r/cX0nGDSSRdJwV3AKcUDiSV+899JNd7AXha0lHpe06StF+Z0/+DpMur0MU0pcQ5K4DPpoPMSNpT0tskTScZwP8+8C2g1F7KjwDvGNL2FZJk9Z0S5x9KiTGNYdxNUuG4RdLuJN1dpYrWvZPGqtZqVeZEYc3mTJJ+9ztJ6vsXnAIsTAeJHwY+V+b1Q8cojipzXsExwMmSHgAeItm7o5RzgT+UdD/JIPIzDL7jICJuBa4HVknqIakwugvJmMB96RTULwND7yYA/pWkSulQpwJvkfR1kruCwp3CoQwen8jjeuCXwAMkifeMiCi1J8WBaTw2Qbh6rFkFpLOktkTEFiUlx7+dDqxX6v1FctdyWLpzXqlz3gf8LUmiujsiKr4nu6Q24KckO+QNVPr9rTE5UZhVgKR3kWys00LyV/3JETF0ltJYr/H7wMsRsV23j6RTSO6qTouIOyp53SHXmUOyJ/Td1bqGNR4nCjMzy+QxCjMzy+REYWZmmZwozMwskxOFmZllcqIwM7NMThRmZpbp/wO17cH4QILV3wAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "estimator = ARMP(iterations=3000, learning_rate=0.075, l1_reg=0.0, l2_reg=0.0, scoring_function=\"rmse\")\n", "\n", @@ -780,7 +767,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.6" + "version": "3.7.2" } }, "nbformat": 4, diff --git a/qml/__init__.py b/qml/__init__.py index 5fca65a7f..aaa25a7ca 100644 --- a/qml/__init__.py +++ b/qml/__init__.py @@ -41,6 +41,7 @@ from . import representations from . import qmlearn from . import utils +from . import helpers from .utils.compound import Compound __author__ = "Anders S. Christensen" diff --git a/qml/helpers/__init__.py b/qml/helpers/__init__.py new file mode 100644 index 000000000..4e0f7e8cf --- /dev/null +++ b/qml/helpers/__init__.py @@ -0,0 +1,23 @@ +# MIT License +# +# Copyright (c) 2017 Anders S. Christensen +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from .helpers import * diff --git a/qml/helpers/helpers.py b/qml/helpers/helpers.py new file mode 100644 index 000000000..a3d799c3c --- /dev/null +++ b/qml/helpers/helpers.py @@ -0,0 +1,75 @@ +# MIT License +# +# Copyright (c) 2017-2019 Anders Steen Christensen, Jakub Wagner +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import numpy as np + +def get_BoB_groups(asize, sort=True): + """ + Get starting and ending indices of bags in Bags of Bonds representation. + + :param asize: Atomtypes and their maximal numbers in the representation + :type asize: dictionary + :param sort: Whether to sort indices as usually automatically done + :type sort: bool + """ + if sort: + asize = {k: asize[k] for k in sorted(asize, key=asize.get)} + n = 0 + low_indices = {} + high_indices = {} + for i, (key1, value1) in enumerate(asize.items()): + for j, (key2, value2) in enumerate(asize.items()): + if j == i: # Comparing same-atoms bonds like C-C + new_key = key1 + key2 + low_indices[new_key] = n + n += int(value1 * (value1+1) / 2) + high_indices[new_key] = n + elif j >= i: # Comparing different-atoms bonds like C-H + new_key = key1 + key2 + low_indices[new_key] = n + n += int(value1 * value2) + high_indices[new_key] = n + return low_indices, high_indices + +def compose_BoB_sigma_vector(sigmas_for_bags, low_indices, high_indices): + """ + Create a vector of per-feature kernel widths. + + In BoB features are grouped by bond types, so a vector of per-group kernel + width would suffice for the computation, however having a per-feature + vector is easier for improving computations with Fortran. + + :param sigmas_for_bags: Kernel widths for different bond types + :type sigmas_for_bags: dictionary + :param low_indices: Starting indices for different bond types + :type low_indices: dictionary + :param high_indices: End indices for different bond types + :type high_indices: dictionary + :return: A vector of per-feature kernel widths + :rtype: numpy array + + """ + length = high_indices[list(sigmas_for_bags.keys())[-1]] + sigmas = np.zeros(length) + for group in sigmas_for_bags: + sigmas[low_indices[group]:high_indices[group]] = sigmas_for_bags[group] + return sigmas diff --git a/qml/kernels/fkernels.f90 b/qml/kernels/fkernels.f90 index 8daaf512e..3054566a0 100644 --- a/qml/kernels/fkernels.f90 +++ b/qml/kernels/fkernels.f90 @@ -555,6 +555,28 @@ subroutine fgaussian_kernel_symmetric(x, n, k, sigma) end subroutine fgaussian_kernel_symmetric +subroutine fgaussian_sigmas_kernel(a, na, b, nb, k, sigmas) + implicit none + double precision, dimension(:,:), intent(in) :: a + double precision, dimension(:,:), intent(in) :: b + double precision, dimension(:), intent(in) :: sigmas + integer, intent(in) :: na, nb + double precision, dimension(:,:), intent(inout) :: k + double precision, allocatable, dimension(:) :: temp + integer :: i, j + + allocate(temp(size(a, dim=1))) + !$OMP PARALLEL DO PRIVATE(temp) COLLAPSE(2) + do i = 1, nb + do j = 1, na + temp(:) = a(:,j) - b(:,i) + k(j,i) = product(exp(-abs(temp * temp / (2 * sigmas * sigmas)))) + enddo + enddo + !$OMP END PARALLEL DO + deallocate(temp) +end subroutine fgaussian_sigmas_kernel + subroutine flaplacian_kernel(a, na, b, nb, k, sigma) implicit none @@ -616,6 +638,28 @@ subroutine flaplacian_kernel_symmetric(x, n, k, sigma) end subroutine flaplacian_kernel_symmetric +subroutine flaplacian_sigmas_kernel(a, na, b, nb, k, sigmas) + implicit none + double precision, dimension(:,:), intent(in) :: a + double precision, dimension(:,:), intent(in) :: b + double precision, dimension(:), intent(in) :: sigmas + integer, intent(in) :: na, nb + double precision, dimension(:,:), intent(inout) :: k + double precision, allocatable, dimension(:) :: temp + integer :: i, j + + allocate(temp(size(a, dim=1))) + !$OMP PARALLEL DO PRIVATE(temp) COLLAPSE(2) + do i = 1, nb + do j = 1, na + temp(:) = a(:,j) - b(:,i) + k(j,i) = product(exp(-abs(temp / sigmas))) + enddo + enddo + !$OMP END PARALLEL DO + deallocate(temp) +end subroutine flaplacian_sigmas_kernel + subroutine flinear_kernel(a, na, b, nb, k) implicit none diff --git a/qml/kernels/kernels.py b/qml/kernels/kernels.py index 5eafc4fe3..162841126 100644 --- a/qml/kernels/kernels.py +++ b/qml/kernels/kernels.py @@ -24,10 +24,12 @@ import numpy as np -from .fkernels import fgaussian_kernel, fgaussian_kernel_symmetric -from .fkernels import flaplacian_kernel +from .fkernels import fgaussian_kernel from .fkernels import fgaussian_kernel_symmetric +from .fkernels import fgaussian_sigmas_kernel +from .fkernels import flaplacian_kernel from .fkernels import flaplacian_kernel_symmetric +from .fkernels import flaplacian_sigmas_kernel from .fkernels import flinear_kernel from .fkernels import fsargan_kernel from .fkernels import fmatern_kernel_l2 @@ -93,6 +95,32 @@ def laplacian_kernel_symmetric(A, sigma): return K +def laplacian_sigmas_kernel(A, B, sigmas): + """ Calculates the Laplacian kernel matrix K, where :math:`K_{ij}`: + + :math:`K_{ij} = \\prod_{k}^{S} \\big( -\\frac{\\|A_{ik} - B_{jk}\\|_1}{\sigma_{k}} \\big)` + + Where :math:`A_{i}` and :math:`B_{j}` are representation vectors and + :math:`S` is the size of representation vector. + K is calculated using an OpenMP parallel Fortran routine. + + :param A: 2D array of representations - shape (N, representation size). + :type A: numpy array + :param B: 2D array of representations - shape (M, representation size). + :type B: numpy array + :param sigmas: Per-feature values of sigma in the kernel matrix - shape (representation_size,) + :type sigmas: numpy array + + :return: The Laplacian kernel matrix - shape (N, M) + :rtype: numpy array + """ + na = A.shape[0] + nb = B.shape[0] + K = np.empty((na, nb), order='F') + # Note: Transposed for Fortran + flaplacian_sigmas_kernel(A.T, na, B.T, nb, K, sigmas) + return K + def gaussian_kernel(A, B, sigma): """ Calculates the Gaussian kernel matrix K, where :math:`K_{ij}`: @@ -148,6 +176,32 @@ def gaussian_kernel_symmetric(A, sigma): return K +def gaussian_sigmas_kernel(A, B, sigmas): + """ Calculates the Gaussian kernel matrix K, where :math:`K_{ij}`: + + :math:`K_{ij} = \\prod_{k}^{S} \\big( -\\frac{\\|A_{ik} - B_{jk}\\|_2^2}{2\sigma_{k}^{2}} \\big)` + + Where :math:`A_{i}` and :math:`B_{j}` are representation vectors and + :math:`S` is the size of representation vector. + K is calculated using an OpenMP parallel Fortran routine. + + :param A: 2D array of representations - shape (N, representation size). + :type A: numpy array + :param B: 2D array of representations - shape (M, representation size). + :type B: numpy array + :param sigmas: Per-feature values of sigma in the kernel matrix - shape (representation_size,) + :type sigmas: numpy array + + :return: The Gaussian kernel matrix - shape (N, M) + :rtype: numpy array + """ + na = A.shape[0] + nb = B.shape[0] + K = np.empty((na, nb), order='F') + # Note: Transposed for Fortran + fgaussian_sigmas_kernel(A.T, na, B.T, nb, K, sigmas) + return K + def linear_kernel(A, B): """ Calculates the linear kernel matrix K, where :math:`K_{ij}`: diff --git a/qml/representations/representations.py b/qml/representations/representations.py index 9cf89c360..bdd77b21d 100644 --- a/qml/representations/representations.py +++ b/qml/representations/representations.py @@ -325,24 +325,22 @@ def get_slatm_mbtypes(nuclear_charges, pbc='000'): :return: A list containing the types of many-body terms. :rtype: list """ - zs = nuclear_charges - nm = len(zs) zsmax = set() nas = [] zs_ravel = [] for zsi in zs: - na = len(zsi); nas.append(na) - zsil = list(zsi); zs_ravel += zsil - zsmax.update( zsil ) - - zsmax = np.array( list(zsmax) ) + na = len(zsi) + nas.append(na) + zsil = list(zsi) + zs_ravel += zsil + zsmax.update(zsil) + zsmax = np.array(list(zsmax)) nass = [] for i in range(nm): zsi = np.array(zs[i],np.int) - nass.append( [ (zi == zsi).sum() for zi in zsmax ] ) - + nass.append([(zi == zsi).sum() for zi in zsmax ]) nzmax = np.max(np.array(nass), axis=0) nzmax_u = [] if pbc != '000': @@ -353,25 +351,22 @@ def get_slatm_mbtypes(nuclear_charges, pbc='000'): nzi = 3 nzmax_u.append(nzi) nzmax = nzmax_u - - boas = [ [zi,] for zi in zsmax ] - bops = [ [zi,zi] for zi in zsmax ] + list( itl.combinations(zsmax,2) ) - + boas = [[zi,] for zi in zsmax] + bops = [[zi,zi] for zi in zsmax] + list(itl.combinations(zsmax,2)) bots = [] for i in zsmax: for bop in bops: - j,k = bop - tas = [ [i,j,k], [i,k,j], [j,i,k] ] + j, k = bop + tas = [[i,j,k], [i,k,j], [j,i,k]] for tasi in tas: if (tasi not in bots) and (tasi[::-1] not in bots): nzsi = [ (zj == tasi).sum() for zj in zsmax ] if np.all(nzsi <= nzmax): - bots.append( tasi ) + bots.append(tasi) mbtypes = boas + bops + bots - return mbtypes #, np.array(zs_ravel), np.array(nas) -def generate_slatm(coordinates, nuclear_charges, mbtypes, +def generate_slatm(coordinates, nuclear_charges, mbtypes=None, unit_cell=None, local=False, sigmas=[0.05,0.05], dgrids=[0.03,0.03], rcut=4.8, alchemy=False, pbc='000', rpower=6): """ @@ -406,11 +401,14 @@ def generate_slatm(coordinates, nuclear_charges, mbtypes, :rtype: numpy array """ + if mbtypes: + mbtypes = mbtypes + else: + mbtypes = get_slatm_mbtypes(nuclear_charges) c = unit_cell - iprt=False + iprt = False if c is None: - c = np.array([[1,0,0],[0,1,0],[0,0,1]]) - + c = np.array([[1,0,0], [0,1,0], [0,0,1]]) if pbc != '000': # print(' -- handling systems with periodic boundary condition') assert c != None, 'ERROR: Please specify unit cell for SLATM' @@ -419,41 +417,40 @@ def generate_slatm(coordinates, nuclear_charges, mbtypes, # info from db, we've already considered this point by letting maximal number # of nuclear charges being 3. # ======================================================================= - zs = nuclear_charges - na = len(zs) + N_atoms = len(zs) coords = coordinates - obj = [ zs, coords, c ] - - iloc = local - - if iloc: + obj = [zs, coords, c] + is_local = local + if is_local: mbs = [] X2Ns = [] - for ia in range(na): - # if iprt: print ' -- ia = ', ia + 1 - n1 = 0; n2 = 0; n3 = 0 + for atom_index in range(N_atoms): + # if iprt: print ' -- atom_index = ', atom_index + 1 + n1 = 0 + n2 = 0 + n3 = 0 mbs_ia = np.zeros(0) icount = 0 for mbtype in mbtypes: if len(mbtype) == 1: - mbsi = get_boa(mbtype[0], np.array([zs[ia],])) + mbsi = get_boa(mbtype[0], np.array([zs[atom_index],])) #print ' -- mbsi = ', mbsi if alchemy: n1 = 1 n1_0 = mbs_ia.shape[0] if n1_0 == 0: - mbs_ia = np.concatenate( (mbs_ia, mbsi), axis=0 ) + mbs_ia = np.concatenate((mbs_ia, mbsi), axis=0) elif n1_0 == 1: mbs_ia += mbsi else: raise '#ERROR' else: n1 += len(mbsi) - mbs_ia = np.concatenate( (mbs_ia, mbsi), axis=0 ) + mbs_ia = np.concatenate((mbs_ia, mbsi), axis=0) elif len(mbtype) == 2: #print ' 001, pbc = ', pbc - mbsi = get_sbop(mbtype, obj, iloc=iloc, ia=ia, \ + mbsi = get_sbop(mbtype, obj, iloc=is_local, ia=atom_index, \ sigma=sigmas[0], dgrid=dgrids[0], rcut=rcut, \ pbc=pbc, rpower=rpower) mbsi *= 0.5 # only for the two-body parts, local rpst @@ -472,9 +469,8 @@ def generate_slatm(coordinates, nuclear_charges, mbtypes, n2 += len(mbsi) mbs_ia = np.concatenate( (mbs_ia, mbsi), axis=0 ) else: # len(mbtype) == 3: - mbsi = get_sbot(mbtype, obj, iloc=iloc, ia=ia, \ + mbsi = get_sbot(mbtype, obj, iloc=is_local, ia=atom_index, \ sigma=sigmas[1], dgrid=dgrids[1], rcut=rcut, pbc=pbc) - if alchemy: n3 = len(mbsi) n3_0 = mbs_ia.shape[0] @@ -488,7 +484,6 @@ def generate_slatm(coordinates, nuclear_charges, mbtypes, else: n3 += len(mbsi) mbs_ia = np.concatenate( (mbs_ia, mbsi), axis=0 ) - mbs.append( mbs_ia ) X2N = [n1,n2,n3]; if X2N not in X2Ns: diff --git a/test/test_local_sigmas.py b/test/test_local_sigmas.py new file mode 100644 index 000000000..25d87cdb1 --- /dev/null +++ b/test/test_local_sigmas.py @@ -0,0 +1,158 @@ +# MIT License +# +# Copyright (c) 2017-2019 Anders Steen Christensen, Jakub Wagner +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import print_function + +import sys +import os +import numpy as np +import scipy + +import qml +from qml.helpers import get_BoB_groups, compose_BoB_sigma_vector +from qml.kernels import gaussian_kernel, laplacian_kernel +from qml.kernels import gaussian_sigmas_kernel, laplacian_sigmas_kernel + +def test_indices_getter(): + """ + Test if indices of BoB groups are correctly returned. + """ + asize = {"C":1, "O":2, "H":2} + asize = {k: asize[k] for k in sorted(asize, key=asize.get)} # Sort asize + correct_low_indices = {'CC': 0, 'CO': 1, 'CH': 3, 'OO': 5, 'OH': 8, 'HH': 12} + correct_high_indices = {'CC': 1, 'CO': 3, 'CH': 5, 'OO': 8, 'OH': 12, 'HH': 15} + low_indices, high_indices = get_BoB_groups(asize) + assert low_indices == correct_low_indices + assert high_indices == correct_high_indices + asize = {"O":5, "C":9, "N":7, "H":20, "F":6} + asize = {k: asize[k] for k in sorted(asize, key=asize.get)} # Sort asize + correct_low_indices = {'OO': 0 , 'OF': 15 , 'ON': 45, 'OC': 80, 'OH': 125, + 'FF': 225, 'FN': 246, 'FC': 288, 'FH': 342, + 'NN': 462, 'NC': 490, 'NH': 553, + 'CC': 693, 'CH': 738, + 'HH': 918} + correct_high_indices = {'OO': 15, 'OF': 45, 'ON': 80, 'OC': 125, 'OH': 225, + 'FF': 246, 'FN': 288, 'FC': 342, 'FH': 462, + 'NN': 490, 'NC': 553, 'NH': 693, + 'CC': 738, 'CH': 918, + 'HH': 1128} + low_indices, high_indices = get_BoB_groups(asize) + assert low_indices == correct_low_indices + assert high_indices == correct_high_indices + +def test_sigma_vector_composition(): + """ + Test if vector of per-feature sigmas is correctly composed. + """ + asize = {"C":1, "O":2, "H":2} + asize = {k: asize[k] for k in sorted(asize, key=asize.get)} # Sort asize + low_indices, high_indices = get_BoB_groups(asize) + sigmas_for_bags = {'CC': 1, 'CO': 2, 'CH': 3, 'OO': 4, 'OH': 5, 'HH': 6} + sigmas = compose_BoB_sigma_vector(sigmas_for_bags, low_indices, high_indices) + correct_sigmas = np.array([1., 2., 2., 3., 3., 4., 4., 4., 5., 5., 5., 5., 6., 6., 6.]) + assert np.allclose(sigmas, correct_sigmas) + +def test_gaussian_sigmas_kernel(): + """ + Test if Gaussian kernel with per-feature sigmas work correctly. + """ + np.random.seed(666) + n_train = 25 + n_test = 20 + n_features = 1000 + # List of dummy representations + X_train = np.random.rand(n_train, n_features) + X_test = np.random.rand(n_test, n_features) + sigmas = np.random.rand(n_features) + K_test = np.ones((n_train, n_test)) + for i in range(n_train): + for j in range(n_test): + for k in range(n_features): + K_test[i,j] *= np.exp(-np.abs((X_train[i,k]-X_test[j,k])**2/(2.0*sigmas[k]**2))) + K = gaussian_sigmas_kernel(X_train, X_test, sigmas) + assert np.allclose(K, K_test) + +def test_laplacian_sigmas_kernel(): + """ + Test if Laplacian kernel with per-feature sigmas work correctly. + """ + np.random.seed(666) + n_train = 25 + n_test = 20 + n_features = 1000 + # List of dummy representations + X_train = np.random.rand(n_train, n_features) + X_test = np.random.rand(n_test, n_features) + sigmas = np.random.rand(n_features) + K_test = np.ones((n_train, n_test)) + for i in range(n_train): + for j in range(n_test): + for k in range(n_features): + K_test[i,j] *= np.exp(-np.abs((X_train[i,k]-X_test[j,k])/(sigmas[k]))) + K = laplacian_sigmas_kernel(X_train, X_test, sigmas) + assert np.allclose(K, K_test) + +def test_single_sigma_gaussian(): + """ + Test if gaussian_sigmas_kernel gives the same result as gaussian_kernel if + all local sigmas are set to the global sigma. + """ + np.random.seed(666) + n_train = 25 + n_test = 20 + # List of dummy representations + X_train = np.random.rand(n_train, 1000) + X_test = np.random.rand(n_test, 1000) + global_sigma = 1.0 + sigmas = np.ones(1000)*global_sigma + K = gaussian_sigmas_kernel(X_train, X_test, sigmas) + K_test = gaussian_kernel(X_train, X_test, global_sigma) + assert np.allclose(K, K_test) + K_symm = gaussian_sigmas_kernel(X_train, X_train, sigmas) + assert np.allclose(K_symm, K_symm.T) + +def test_single_sigma_laplacian(): + """ + Test if laplacian_sigmas_kernel gives the same result as laplacian_kernel + if all local sigmas are set to the global sigma. + """ + np.random.seed(666) + n_train = 25 + n_test = 20 + # List of dummy representations + X_train = np.random.rand(n_train, 1000) + X_test = np.random.rand(n_test, 1000) + global_sigma = 1.0 + sigmas = np.ones(1000)*global_sigma + K = laplacian_sigmas_kernel(X_train, X_test, sigmas) + K_test = gaussian_kernel(X_train, X_test, global_sigma) + assert np.allclose(K, K_test) + K_symm = laplacian_sigmas_kernel(X_train, X_train, sigmas) + assert np.allclose(K_symm, K_symm.T) + +if __name__ == "__main__": + test_indices_getter() + test_sigma_vector_composition() + test_gaussian_kernel() + test_laplacian_kernel() + test_single_sigma_gaussian() + test_single_sigma_laplacian()