Search Blogs

Thursday, May 16, 2024

Update on Crystal Graph CNN

I've finally made progress, thanks¹ to a switch from GeometricFlux.jl to GraphNeuralNetworks.jl. Both are good, but I find GraphNeuralNetworks.jl easier to use. I originally wrote about my interest in doing something with GNNs in this blog post, and then in a few other general posts.

If you recall, the original post was on the Crystal Graph CNN (CGCNN) for property prediction. The original paper, by Xie and Grossman [1], is in my opinion very easy to follow and understand, which is why I chose it. I also think it set off some of the interest in GNNs in materials science. The nice thing is that the paper used Materials Project data, which I'm familiar with.

Original CGCNN Dataset

One thing I've noticed is that the .csv files in the original CGCNN repo contain Materials Project IDs that are no longer valid for querying the Materials Project API. I'm not sure whether this is an issue with my Julia REST API function, or whether the Materials Project drops IDs once it has more accurate calculations for those structures and chemical systems. My guess is that the REST API endpoints I'm using are wrong and need to be modified.
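Until the ID question is sorted out, one workaround is to keep only the rows whose IDs can still be resolved and set the rest aside. The sketch below assumes the CGCNN `id_prop.csv` layout of one `id,target` pair per line with no header. It also assumes you already have a set of IDs known to be valid, for example from a fresh API query. `split_by_valid_ids` is a hypothetical helper, and the IDs and values are made up.

```julia
# Hypothetical helper: split CGCNN-style id_prop rows ("id,target" per line,
# no header) into rows whose IDs are in a set of currently valid IDs and
# the IDs that are not.
function split_by_valid_ids(csv_text::AbstractString, valid_ids::AbstractSet{<:AbstractString})
    kept = Tuple{String, Float64}[]
    dropped = String[]
    for line in eachline(IOBuffer(csv_text))
        isempty(strip(line)) && continue                 # skip blank lines
        parts = split(line, ','; limit=2)
        id = String(strip(parts[1]))
        target = parse(Float64, strip(parts[2]))
        if id in valid_ids
            push!(kept, (id, target))
        else
            push!(dropped, id)
        end
    end
    return kept, dropped
end

# Usage with made-up IDs and values (not real Materials Project entries)
csv_text = """
mp-1,-1.5
mp-2,0.25
mp-3,-0.75
"""
valid = Set(["mp-1", "mp-3"])
kept, dropped = split_by_valid_ids(csv_text, valid)
println(kept)     # [("mp-1", -1.5), ("mp-3", -0.75)]
println(dropped)  # ["mp-2"]
```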

So how did I make progress? Well, I finally got all the innards working to prepare the graphs. The final Julia code for that looks like:

```julia
using LinearAlgebra: norm
using Graphs                # SimpleGraph, add_edge!
using GraphNeuralNetworks   # GNNGraph
using Unitful               # u"Å"
using AtomsBase             # atomic_number
using AtomsIO               # load_system
# PairList and neigs come from a neighbor-list package (e.g. NeighbourLists.jl);
# CIFData, get_atom_fea, construct_adjacency_mat and expand are supporting
# definitions not shown in this post.

"""
    get_item(cifdata::CIFData, idx; dtype=Float32)

Prepare a graph for a Graph Neural Network (GNN) from CIF data.

# Arguments
- `cifdata::CIFData`: An instance of `CIFData` containing crystallographic information.
- `idx`: An index identifying the specific crystal structure within `cifdata`.
- `dtype`: The data type for the node and edge features (default is `Float32`).

# Returns
- `gnn_graph::GNNGraph`: A graph object compatible with GNNs, containing node and edge features.

# Details
This function performs the following steps:
1. Loads the crystal structure from a `.cif` file.
2. Extracts atomic numbers and computes node features.
3. Initializes a graph with nodes corresponding to atoms in the crystal.
4. Builds a neighbor list and constructs the adjacency matrix.
5. Computes edge features using an expanded Gaussian Radial Basis Function (RBF).
6. Creates a `GNNGraph` object with the node and edge features.
"""
function get_item(cifdata::CIFData, idx; dtype=Float32)
    cif_id, target = cifdata.id_prop_data[idx]
    crystal = load_system(joinpath(cifdata.root_dir, cif_id * ".cif"))

    # Atom/Node features
    at_nums = @. atomic_number(crystal)
    atom_features = [dtype.(get_atom_fea(cifdata.ari, at)) for at in at_nums]
    node_features = hcat(atom_features...)  # Convert to matrix form (features × atoms)

    num_atoms = length(crystal)
    g = SimpleGraph(num_atoms)

    # Build neighbor list -> adjacency matrix
    nlist = PairList(crystal, cifdata.radius * u"Å")
    adj_mat = construct_adjacency_mat(nlist)  # not used further in this function
    edge_features = Vector{Vector{dtype}}()
    processed_edges = Set{Tuple{Int, Int}}()

    # Construct edge features in expanded Gaussian RBF
    # 1. We have to track whether an edge/pair has already been
    #    assigned a feature vector.
    # 2. The implementation here differs from CGCNN.py in that we
    #    don't process all edge features for a node/atom in one shot.
    # 3. Neighbors are visited in ascending index order so the columns of
    #    the edge feature matrix follow the same (i, j), i ≤ j order that
    #    `edges(g)` uses for a SimpleGraph.
    for i in 1:num_atoms
        nbrs, nbrs_dist_vecs = neigs(nlist, i)
        for k in sortperm(nbrs)
            j, dist_vec = nbrs[k], nbrs_dist_vecs[k]
            edge = i < j ? (i, j) : (j, i)  # Ensure unique representation
            if edge ∉ processed_edges
                add_edge!(g, edge[1], edge[2])  # Add edge to the graph
                dist = norm(dist_vec)
                dist_basis = expand(dist, cifdata.gdf)
                push!(edge_features, dtype.(dist_basis[:]))
                push!(processed_edges, edge)
            end
        end
    end
    edge_feature_matrix = hcat(edge_features...)

    # GNNGraph stores each undirected edge as two opposing directed edges;
    # features given once per undirected edge are reused for both directions.
    gnn_graph = GNNGraph(g; ndata=node_features, edata=edge_feature_matrix)
    return gnn_graph
end
```
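Step 5 depends on `expand` and `cifdata.gdf`, which aren't shown. The sketch below assumes they follow the `GaussianDistance` class in the original CGCNN PyTorch code. In that class the centers μ_k run from `dmin` to `dmax` in increments of `step`, the width `var` defaults to `step`, and each distance d maps to exp(-(d - μ_k)² / var²). `GaussianDistanceSketch` and `expand_sketch` are hypothetical names, distances are plain numbers in Å rather than Unitful quantities, and the parameter values are only illustrative.

```julia
# Hypothetical stand-in for the post's `cifdata.gdf` / `expand` helpers,
# modeled on the GaussianDistance class in the original CGCNN code.
struct GaussianDistanceSketch
    centers::Vector{Float64}   # μ_k, evenly spaced from dmin to dmax
    var::Float64               # Gaussian width parameter
end

function GaussianDistanceSketch(dmin, dmax, step; var=step)
    return GaussianDistanceSketch(collect(range(dmin, dmax; step=step)), var)
end

# Expand one distance d into exp(-(d - μ_k)^2 / var^2) for every center μ_k
expand_sketch(d::Real, gdf::GaussianDistanceSketch) =
    exp.(-((d .- gdf.centers) .^ 2) ./ gdf.var^2)

gdf = GaussianDistanceSketch(0.0, 8.0, 0.2)
feat = expand_sketch(2.0, gdf)
println(length(feat))   # 41 basis functions
println(argmax(feat))   # 11, the center sitting at 2.0 Å
```

Each bond becomes a smooth, fixed-length vector that peaks at the center nearest its length. This is what lets a dense layer consume the distance information.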

There are a lot of supporting functions that are implemented but not shown here; hopefully the docstring and function names make clear what they do. Once I figured out how to do this, it was pretty simple to construct a neural network with different convolutional, pooling, and dense layers. I haven't fully implemented the architecture described in [1], so I won't show that code here yet. Judging by the MAE loss curves below (~1K data points), the model I do have isn't doing too well, but it is not the model from the paper. I also don't have the compute to handle the full dataset (~70K).
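For reference, the gated convolution proposed in [1] updates each atom vector as v_i ← v_i + Σ_j σ(W_f z + b_f) ⊙ g(W_s z + b_s). Here z = v_i ⊕ v_j ⊕ u_ij, σ is the sigmoid, and g is softplus. The sketch below writes that update with plain arrays, together with mean pooling over atoms. It is not the Flux.jl/GraphNeuralNetworks.jl model from this post: `CGConvSketch`, `cgconv`, and `mean_pool` are hypothetical names, and the batch normalization used in the original PyTorch implementation is left out.

```julia
# Minimal sketch of the CGCNN gated convolution with plain arrays.
sigmoid(x) = 1 / (1 + exp(-x))
softplus(x) = log1p(exp(x))

struct CGConvSketch
    Wf::Matrix{Float64}; bf::Vector{Float64}   # gate ("filter") weights
    Ws::Matrix{Float64}; bs::Vector{Float64}   # core weights
end

# V: node features (F × N); U: edge features (E × M) for M directed
# edges src[m] -> dst[m]; node dst[m] receives a message from src[m].
function cgconv(layer::CGConvSketch, V::Matrix{Float64}, U::Matrix{Float64},
                src::Vector{Int}, dst::Vector{Int})
    Vnew = copy(V)                                # residual term v_i
    for m in eachindex(src)
        i, j = dst[m], src[m]
        z = vcat(V[:, i], V[:, j], U[:, m])       # z = v_i ⊕ v_j ⊕ u_ij
        gate = sigmoid.(layer.Wf * z .+ layer.bf)
        core = softplus.(layer.Ws * z .+ layer.bs)
        Vnew[:, i] .+= gate .* core               # element-wise product
    end
    return Vnew
end

# Crystal-level vector: average the atom features (mean pooling)
mean_pool(V) = vec(sum(V; dims=2)) ./ size(V, 2)

# Usage: 2 atoms, one undirected bond stored as two directed edges,
# all weights zero so each message is sigmoid(0) * softplus(0) = 0.5 * log(2)
F, E = 2, 3
layer = CGConvSketch(zeros(F, 2F + E), zeros(F), zeros(F, 2F + E), zeros(F))
V = [1.0 2.0; 3.0 4.0]
U = ones(E, 2)
Vout = cgconv(layer, V, U, [1, 2], [2, 1])
crystal_vec = mean_pool(Vout)
```

Storing the bond as the pair (1→2, 2→1) mirrors how `GNNGraph` represents undirected edges. As a result, both atoms receive the same message here.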

I can at least start from here! Obviously it looks pretty bad.

If you're looking to use the function above, just shoot me an email and I'll share the current Pluto notebook. The notebook itself will become a post on my computational blog once I've reproduced a metric from ref. [1] within some reasonable error. I don't know what the timeline looks like, because it depends on my ability to train the model on my computational resources. At the moment I don't have a GPU to train on.

Training Update

I think I've managed to get the correct model representation with Flux.jl; I'm now getting training and validation curves that look more reasonable.

Training curves are looking better; there's still room to improve.
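Training and validation curves like these come from holding out part of the data and tracking the mean absolute error (MAE) on both sets. Below is a minimal sketch of that bookkeeping using only the Julia standard library. `train_val_split` is a hypothetical helper, and the 20% validation fraction is an illustrative choice, not the split used in [1].

```julia
using Random

# Hypothetical helper: shuffle indices 1:n and split them into training and
# validation sets; `val_frac` is an assumed parameter.
function train_val_split(n::Integer, val_frac::Real; rng=Random.default_rng())
    idx = randperm(rng, n)
    n_val = round(Int, val_frac * n)
    return idx[n_val+1:end], idx[1:n_val]   # (train, validation)
end

# Mean absolute error, the metric shown in the loss curves
mae(ŷ, y) = sum(abs.(ŷ .- y)) / length(y)

train_idx, val_idx = train_val_split(10, 0.2; rng=MersenneTwister(42))
println(length(train_idx), " ", length(val_idx))   # 8 2
println(mae([1.0, 2.0, 3.0], [1.5, 2.0, 2.0]))      # 0.5
```

Passing an explicit `rng` keeps the split reproducible between runs. That matters when comparing curves from different model variants.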

Footnotes


  1. Also thanks to GitHub user @aurorarossi, who explained how GNNGraph treats an undirected graph as having two opposing directed edges for each node pair.


References

[1] T. Xie, J.C. Grossman, Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties, Phys. Rev. Lett. 120 (2018) 145301. https://doi.org/10.1103/PhysRevLett.120.145301.


