Search Blogs

Thursday, May 30, 2024

What am I getting wrong here?

The implementation of Xie and Grossman's crystal graph convolutional neural network (CGCNN) [1] in Julia using Flux.jl and GraphNeuralNetworks.jl [2] is moving along. As I mentioned in my previous post, I got the data wrangling and processing done correctly. Well, at least I think so. It is hard to make an apples-to-apples comparison of the graph structure between the Python implementation and my Julia code, because in the original PyTorch version by Xie and Grossman the graph structure was only conceptual: there were no class objects that explicitly represented a graph. It would be useful if someone familiar with PyTorch Geometric, CGCNN, and Flux.jl/GraphNeuralNetworks.jl could check things out for me. 🙏

But assuming the graph structure of the data is correct, the even more difficult part is constructing the neural network layers. In the original implementation, the node features and edge features are concatenated into a single feature vector on the nodes, and then a convolutional weight matrix (i.e., kernel) is applied to "reshape" that vector and update the node features. Let me back up: the concatenated features are aggregated by summing over the edges (bonds) that connect each node to its neighbors, where the neighbors are the atoms within some cutoff distance. In GNN parlance, I think this is just a message passing scheme, but I'm not entirely sure.
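The cutoff-based bond construction and summed aggregation described above can be sketched in plain Julia (no Flux or GraphNeuralNetworks.jl) as follows. This is a minimal, non-periodic sketch under stated assumptions: `neighbor_edges`, `gaussian_expand`, and `aggregate` are hypothetical helpers, positions are stored as a 3 × N matrix, the Gaussian centers and width are illustrative values, and real crystals also need periodic images (which is where multiple bonds between the same pair of atoms come from).

```julia
using LinearAlgebra

# Hypothetical helper: directed bonds j -> i for every pair closer than `cutoff`.
# Non-periodic, so each ordered pair appears at most once.
function neighbor_edges(positions::Matrix{Float64}, cutoff::Float64)
    n = size(positions, 2)
    src, dst, dist = Int[], Int[], Float64[]
    for i in 1:n, j in 1:n
        i == j && continue
        d = norm(positions[:, i] - positions[:, j])
        if d < cutoff
            push!(src, j)   # message comes from neighbor j ...
            push!(dst, i)   # ... and is delivered to atom i
            push!(dist, d)
        end
    end
    return src, dst, dist
end

# Gaussian (radial basis) expansion of one distance; centers/width are illustrative.
gaussian_expand(d; centers = 0.0:0.5:3.0, σ = 0.5) = exp.(-((d .- centers) .^ 2) ./ σ^2)

# Sum aggregation: for each atom i, add up v_j ⊕ u_(i,j) over its incoming bonds.
function aggregate(v::Matrix{Float64}, u::Matrix{Float64}, src, dst)
    m = zeros(size(v, 1) + size(u, 1), size(v, 2))
    for e in eachindex(src)
        m[:, dst[e]] += vcat(v[:, src[e]], u[:, e])
    end
    return m
end

# Usage: three atoms on a line; only atoms 1 and 2 are within the cutoff.
positions = [0.0 1.0 3.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
src, dst, dist = neighbor_edges(positions, 1.5)
u = reduce(hcat, gaussian_expand.(dist))   # one column of edge features per bond
v = [1.0 2.0 3.0]                          # one scalar feature per atom
m = aggregate(v, u, src, dst)              # atom 3 has no neighbors, so its column is zero
```

Here the edge features never change; only the aggregated messages feed into the node update.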

So we have an update that takes the edge features, concatenates them with the node features, and then applies a convolutional weight matrix. We also include a self-update of the original node features using a separate weight matrix. This type of operation updates the node features but not the edge features, which may make sense because in the original CGCNN implementation the edge features are fixed Gaussian (radial basis function) expansions of the interatomic distances. To summarize, this is the update (eq. 5 in [1]):

$$ \begin{align} z^{(t)}_{(i,j)_k} &= v_{i}^{(t)} \oplus v_{j}^{(t)} \oplus u_{(i,j)_k} \label{eq:feature} \\ v_i^{(t+1)} &= v_i^{(t)} + \sum_{j,k} \sigma\left(z^{(t)}_{(i,j)_k} \mathbf{W}_f^{(t)} + \mathbf{b}_f^{(t)} \right) \odot g\left(z^{(t)}_{(i,j)_k} \mathbf{W}_s^{(t)} + \mathbf{b}_s^{(t)} \right) \label{eq:update} \end{align} $$

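Here $t$ is the layer number (footnote 1). The key aspect is that the filter $\sigma\left(z^{(t)}_{(i,j)_k} \mathbf{W}_f^{(t)} + \mathbf{b}_f^{(t)}\right)$ acts as a learned weight specific to each of the $k$ bonds of the $i$-th atom (footnote 2): $\mathbf{W}^{(t)}_f$ itself is shared, but because $z^{(t)}_{(i,j)_k}$ depends on the particular bond, the effective weighting differs from neighbor to neighbor. I'm pretty sure my current implementation does not do this and instead does what is described in eq. 4 of the original paper: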

$$ \begin{align} v_i^{(t+1)} = g\Bigg[ &\left(\sum_{j,k} v_j^{(t)} \oplus u_{(i,j)_k}\right)\mathbf{W}_c^{(t)} \nonumber \\ &+ v_i^{(t)}\,\mathbf{W}_s^{(t)} + \mathbf{b}^{(t)}\Bigg]\label{eq:Xie_Grossman_eq4} \end{align} $$
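To make the difference between eq. 4 and the gated update concrete, here is a minimal plain-Julia sketch of one layer of each. Assumptions: features are stored as columns and weights multiply from the left ($\mathbf{W} z$, the transpose of the paper's $z \mathbf{W}$), bonds are given as hypothetical `src`/`dst` index vectors with one column of `u` per bond, ReLU is a placeholder for the unspecified nonlinearity $g$ in eq. 4, and softplus is used for $g$ in the gated update, as in the original PyTorch `ConvLayer` (whose batch normalization is omitted here). The function names are made up for illustration.

```julia
# Activations: sigmoid for the gate, softplus for g in the gated update,
# ReLU as a placeholder nonlinearity for eq. 4.
sigmoid(x) = 1 / (1 + exp(-x))
softplus(x) = log1p(exp(-abs(x))) + max(x, zero(x))   # numerically stable
relu(x) = max(x, zero(x))

# Eq. 4: sum v_j ⊕ u_(i,j)_k over all bonds of atom i FIRST, then apply the
# shared W_c, so every neighbor is weighted identically.
function eq4_update(v, u, src, dst, Wc, Ws, b)
    msg = zeros(size(v, 1) + size(u, 1), size(v, 2))
    for e in eachindex(src)
        msg[:, dst[e]] += vcat(v[:, src[e]], u[:, e])
    end
    return relu.(Wc * msg .+ Ws * v .+ b)
end

# Gated update: each bond gets its own filter σ(W_f z + b_f) because z
# depends on the bond, even though W_f is shared.
function gated_update(v, u, src, dst, Wf, bf, Ws, bs)
    vnew = copy(v)                              # residual term v_i^(t)
    for e in eachindex(src)
        i, j = dst[e], src[e]
        z = vcat(v[:, i], v[:, j], u[:, e])     # v_i ⊕ v_j ⊕ u_(i,j)_k
        gate = sigmoid.(Wf * z .+ bf)           # per-bond filter
        core = softplus.(Ws * z .+ bs)
        vnew[:, i] += gate .* core              # Σ over bonds (j, k) of atom i
    end
    return vnew
end

# Usage: atom 1 has two bonds to atom 2 (two values of k), atom 3 has none.
v = [1.0 2.0 3.0; 0.0 1.0 0.0]
src, dst = [2, 2, 1], [1, 1, 2]
u = [1.0 1.2 1.0]
out5 = gated_update(v, u, src, dst, zeros(2, 5), zeros(2), zeros(2, 5), zeros(2))
```

Note that a plain 0/1 adjacency matrix can hold only one edge per ordered pair of atoms, so it cannot represent the two separate bonds from atom 2 to atom 1 used in this example.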

The other thing that confuses me is that the pooling operation appears to be applied to the node features from every layer update, not just the last one. This makes sense, but I don't think I'm doing it correctly.

Here is my current Flux.jl implementation of the CGCNN model:

```julia
using Flux
using GraphNeuralNetworks

# Define the CGCNN struct
struct CGCNN
    embedding::Dense
    convs::Vector{CGConv}
    conv_to_fc::Dense
    conv_to_fc_softplus::Function
    fc_out::Dense
end

# Constructor for CGCNN ...

# Pooling function: average the atom features over all nodes (columns)
function pooling(atom_fea)
    return sum(atom_fea, dims=2) ./ size(atom_fea, 2)
end

# Forward pass
function (model::CGCNN)(g::GNNGraph)
    atom_fea = g.ndata[:x]   # node (atom) features
    edge_fea = g.edata[:e]   # edge (bond) features

    atom_fea = model.embedding(atom_fea)
    # Each CGConv layer updates the node features; edge features stay fixed
    for conv in model.convs
        atom_fea = conv(g, atom_fea, edge_fea)
    end

    # Only the final layer's atom features are pooled into a crystal feature
    crys_fea = pooling(atom_fea)
    crys_fea = model.conv_to_fc_softplus(model.conv_to_fc(crys_fea))
    out = model.fc_out(crys_fea)
    return out
end
```

My concern is the pooling operation: as you can see, it is applied only to the node/atom features from the final layer, which is not what eq. 2 in Xie and Grossman's paper shows. The catch is that Zygote.jl, which does the automatic differentiation for back-propagation, fails with an error about mutating arrays if I try to store the atom_fea output of each graph convolution layer.
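One way to keep every layer's pooled features without writing into an existing array is to rebind a variable to a newly built matrix at each step, as in the minimal plain-Julia sketch below. Assumptions: `layers` is a hypothetical list of one-argument functions standing in for the CGConv calls (with GraphNeuralNetworks.jl the loop body would be `conv(g, h, edge_fea)`), and `mean_pool` mirrors the `pooling` function above. This has not been checked against Zygote here; the point is only that it avoids `push!` and `setindex!`. If in-place writes are unavoidable, Zygote provides `Zygote.Buffer` for that case.

```julia
# Mean over atoms (columns), same as `pooling` in the post.
mean_pool(h) = sum(h, dims = 2) ./ size(h, 2)

# Collect the pooled features of the embedding and of every layer.
# `pooled` is rebound to a NEW matrix built by hcat at each step, so no
# existing array is mutated.
function pool_all_layers(layers, h)
    pooled = mean_pool(h)                    # t = 0 (embedded features)
    for layer in layers
        h = layer(h)
        pooled = hcat(pooled, mean_pool(h))  # one column per layer
    end
    return pooled
end

# Usage with two toy "layers" acting on a 2-feature, 2-atom matrix.
layers = [x -> 2 .* x, x -> x .+ 1.0]
h = [1.0 3.0; 2.0 4.0]
pooled = pool_all_layers(layers, h)
```

The resulting matrix could then be flattened with `vec` or combined across layers before the fully connected layers, depending on how eq. 2 is interpreted.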

Well, I will keep working through this until I get something that trains to the same mean absolute error as reported in the original paper. Also, the softmax should be applied to each layer's features, $v_i^{(t)}$, prior to pooling, which I'm not doing (see eq. S1 in [1]).
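The softmax-before-pooling step could look something like the minimal sketch below. It rests on assumptions that have not been checked against eq. S1: the softmax runs over the feature dimension of each atom, each layer's result is averaged over atoms, and the per-layer vectors are summed. `hs` is a hypothetical vector holding the atom-feature matrix (features × atoms) from each layer, and the function names are made up for illustration.

```julia
# Column-wise softmax: normalizes each atom's feature vector to sum to 1.
function softmax_cols(h)
    e = exp.(h .- maximum(h, dims = 1))   # subtract the max for numerical stability
    return e ./ sum(e, dims = 1)
end

# Assumed readout: softmax each layer's atom features, average over atoms,
# then sum the per-layer vectors. No array is mutated.
softmax_readout(hs) = sum(h -> sum(softmax_cols(h), dims = 2) ./ size(h, 2), hs)

# Usage: two layers with different numbers of atoms is fine, since each
# layer is averaged over its own atoms first.
r = softmax_readout([zeros(3, 2), zeros(3, 4)])
```

Because each layer contributes a vector whose entries sum to 1, the crystal vector's entries sum to the number of layers.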

It would have been nice if Xie and Grossman's paper had included a detailed neural network diagram in addition to Fig. 1. I'm probably missing some small implementation detail. The one thing I know is correct is the conv function call, because it is just CGConv from GraphNeuralNetworks.jl, which is an exact implementation of eq. 5 in [1].

Footnotes


  1. I think it is easier to think of $t$ as the iteration step, because at each layer the node features are updated (but not the edge features). 

  2. If you look at Fig. S1 in Xie and Grossman's paper, you see that they actually represent the graph using symmetric bonds, and thus a pair of nodes in their representation can be connected by multiple edges/bonds (a consequence of the crystal's periodicity). This is why the $k$ index is used. In my implementation, I just build an adjacency matrix and use it to define the graph, which can hold at most one edge per ordered pair of atoms. 


References

[1] T. Xie, J.C. Grossman, Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties, Phys. Rev. Lett. 120 (2018) 145301. https://doi.org/10.1103/PhysRevLett.120.145301.

[2] C. Lucibello and contributors, GraphNeuralNetworks.jl: a geometric deep learning library for the Julia programming language (2021). https://github.com/CarloLucibello/GraphNeuralNetworks.jl.


