
Thursday, May 16, 2024

Update on Crystal Graph CNN

I've finally made progress! Thanks¹ to a shift from using GeometricFlux.jl to GraphNeuralNetworks.jl. Both are good, but I find GraphNeuralNetworks.jl easier to use. I originally wrote about my interest in doing something with GNNs in this blog post, and then in a few other general posts.

If you recall, the original post was on Crystal Graph CNNs (CGCNN) for property prediction. The original paper was by Xie and Grossman [1] and is, in my opinion, a very easy paper to follow and understand, hence why I chose it. I also think it set off a good bit of the interest in GNNs in materials science. The nice thing is that the paper used Materials Project data, which I'm familiar with.

Original CGCNN Dataset

One thing I've noticed is that the Materials Project ids in the .csv files in the original CGCNN repo are no longer valid for querying the Materials Project API. I'm not sure if this is an issue with my REST API Julia function, or whether the Materials Project drops ids for structures when it gets more accurate calculations for those structures and chemical systems. My guess is that the REST API endpoints I am using are wrong and I need to make modifications.
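As a sanity check, something along these lines is what I have in mind for testing whether a given id still resolves against the next-gen Materials Project API; the endpoint path, query parameters, and response handling here are my assumptions from the API docs, not tested code:

```julia
using HTTP, JSON3

# Hypothetical helper: returns true if a Materials Project id still resolves.
# Assumes the next-gen API at api.materialsproject.org and an API key stored
# in the MP_API_KEY environment variable; the endpoint path is my assumption.
function mp_id_exists(mp_id::AbstractString)
    url = "https://api.materialsproject.org/materials/summary/?material_ids=$(mp_id)&_fields=material_id"
    headers = ["X-API-KEY" => ENV["MP_API_KEY"]]
    resp = HTTP.get(url, headers; status_exception=false)
    resp.status == 200 || return false
    body = JSON3.read(resp.body)
    return !isempty(get(body, :data, []))  # empty data => id not found/deprecated
end

mp_id_exists("mp-1234")  # e.g., check one id pulled from id_prop.csv
```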

So how did I make progress? Well, I finally got all the innards working to prepare the graphs. The final Julia code to do that looks like:

""" get_item(cifdata::CIFData, idx; dtype=Float32)
Prepare a graph for a Graph Neural Network (GNN) from CIF data.
# Arguments - cifdata::CIFData: An instance of CIFData containing crystallographic information. - idx: An index identifying the specific crystal structure within cifdata. - dtype: The data type for the node and edge features (default is Float32).
# Returns - gnn_graph::GNNGraph: A graph object compatible with GNNs, containing node and edge features.
# Details This function performs the following steps: 1. Loads the crystal structure from a .cif file. 2. Extracts atomic numbers and computes node features. 3. Initializes a graph with nodes corresponding to atoms in the crystal. 4. Builds a neighbor list and constructs the adjacency matrix. 5. Computes edge features using an expanded Gaussian Radial Basis Function (RBF). 6. Creates a GNNGraph object with the node and edge features. """ function get_item(cifdata::CIFData, idx; dtype=Float32) cif_id, target = cifdata.id_prop_data[idx] crystal = load_system(joinpath(cifdata.root_dir, join([cif_id, ".cif"])))
# Atom/Node Features at_nums = @. atomic_number(crystal) atom_features = [dtype.(get_atom_fea(cifdata.ari, at)) for at in at_nums] node_features = hcat(atom_features...) # Convert to matrix form
num_atoms = length(crystal) g = SimpleGraph(num_atoms)
# Build Neighbor list -> Adjacency Matrix nlist = PairList(crystal, cifdata.radius*u"Å") adj_mat = construct_adjacency_mat(nlist) edge_features = [] processed_edges = Set{Tuple{Int, Int}}() # Construct edge features in expanded Gaussian RBF # 1. We have to track whether a edge/pair has been # assigned a feature vector. # 2. The implementation here differs from CGCNN.py # in that we don't process all edge features for # a node/atom in one shot. for i in 1:num_atoms nbrs, nbrs_dist_vecs = neigs(nlist, i) for (j, dist_vec) in zip(nbrs, nbrs_dist_vecs) edge = i < j ? (i, j) : (j, i) # Ensure unique representation if edge processed_edges add_edge!(g, edge[1], edge[2]) # Add edge to the graph dist = norm(dist_vec) dist_basis = expand(dist, cifdata.gdf) push!(edge_features, dtype.(dist_basis[:])) push!(processed_edges, edge) end end end edge_feature_matrix = hcat(edge_features...)
gnn_graph = GNNGraph(g; ndata=node_features, edata=edge_feature_matrix)
return gnn_graph end
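For what it's worth, the reason I wanted GNNGraph objects is that they can be merged into batches for training. If I recall the GraphNeuralNetworks.jl docs correctly, a minimal usage sketch (assuming `cifdata` is already constructed) would be something like:

```julia
using MLUtils: batch

# Build graphs for the first handful of entries and merge them into a single
# batched GNNGraph (block-diagonal adjacency), which is what the model consumes.
graphs = [get_item(cifdata, i) for i in 1:8]
g_batch = batch(graphs)
```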

There are a lot of supporting functions that are implemented but not shown here. Hopefully from the docstring and function naming you can see what they do. Once I figured out how to do this, it was pretty simple to construct a neural network with different convolutional, pooling, and dense layers. I haven't fully implemented the correct architecture as described in [1], so I won't show that code here yet. The model I do have isn't doing too well, given the MAE loss curves below (~1K data points), but this is not the model in the paper. I also don't have the compute to deal with the total dataset size (~70K).

I can at least start from here! Obviously it looks pretty bad.
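For the curious, here is a minimal sketch of what I mean by wiring convolutional, pooling, and dense layers together with GraphNeuralNetworks.jl; the layer sizes and the use of `CGConv` here are placeholders of my choosing, not the architecture from [1]:

```julia
using Flux, GraphNeuralNetworks
using Statistics: mean

n_node_feats = 92   # length of the atom feature vectors
n_edge_feats = 41   # length of the expanded Gaussian RBF

conv1 = CGConv((n_node_feats, n_edge_feats) => n_node_feats, relu)
pool  = GlobalPool(mean)
head  = Chain(Dense(n_node_feats => 64, relu), Dense(64 => 1))

# Predict a scalar property from a (possibly batched) GNNGraph
function predict(g::GNNGraph)
    x = conv1(g, g.ndata.x, g.edata.e)  # graph convolution using edge features
    h = pool(g, x)                      # per-graph mean pooling
    return vec(head(h))                 # one prediction per graph in the batch
end
```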

If you're looking to utilize the function above, just shoot me an email and I'll share the current Pluto notebook. The notebook itself will become a post on my computational blog once I've successfully reproduced, within some reasonable error, a metric from ref. [1]. I don't know what the timeline looks like because it will depend on my ability to train the model on my computational resources; at the moment I don't have a GPU to train on.

Footnotes


  1. Also thanks to the GitHub user @aurorarossi, who provided some basic understanding of how GNNGraph treats undirected graphs as having two opposing directed edges for each node pair. 


References

[1] T. Xie, J.C. Grossman, Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties, Phys. Rev. Lett. 120 (2018) 145301. https://doi.org/10.1103/PhysRevLett.120.145301.




Thursday, May 9, 2024

KANs and AlphaFold 3

Taking a break from my posts and reading on thermodynamic computing to just earmark here a few of the interesting things going on in the ML/AI world.

Kolmogorov-Arnold networks (KAN)

Last week there was a preprint about a new deep learning architecture called Kolmogorov-Arnold networks (KANs) [1]. These are in contrast to multi-layer perceptrons (MLPs). What is really interesting is that instead of learning the linear weights on edges, $\mathbf{W}_i$¹, of a layer:

$$ \sigma(\mathbf{W}_i\cdot\mathbf{x}_{i-1} + \mathbf{b}_i) $$,

that are passed to a fixed non-linear activation function, $\sigma$, the network learns the parameterized non-linear activation functions themselves. The learning of the non-linear functions is done on the edges, and then a summation operation is performed on the nodes (i.e., where the activation function in MLPs would normally be applied). Mathematically, for the interlayer (i.e., edge) operation you would have:

$$ \begin{equation} \Phi(\sum_i^n \phi_i(x,\mathbf{\alpha})) \label{eq:KA} \end{equation} $$

here the $i$'s are the edges the functions live on, linking nodes between layers. The key point is that the $\phi_i$ are the learned non-linear functions with parameters $\mathbf{\alpha}$². In the paper they use B-splines to learn the functions $\phi_i$. My guess is they chose these over other basis functions because B-splines can be locally controlled and are smooth, but there should be no reason one couldn't use other functions.
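To make the idea concrete for myself, here is a toy sketch of a KAN-style layer in Julia, where every input-output edge carries its own learnable 1D function; I'm using a Gaussian RBF expansion as a stand-in for the paper's B-splines, and all names here are made up:

```julia
# Toy KAN-style layer: every input->output edge (i, j) has its own learnable
# 1D function φ_{j,i}(x) = Σ_k c[k, i, j] * exp(-((x - μ_k)/σ)^2).
struct ToyKANLayer
    centers::Vector{Float64}    # fixed grid of basis centers μ_k
    coeffs::Array{Float64,3}    # learnable coefficients c[k, in, out]
    σ::Float64
end

function ToyKANLayer(nin::Int, nout::Int; nbasis::Int=8, σ=0.5)
    centers = collect(range(-1, 1; length=nbasis))
    coeffs = 0.1 .* randn(nbasis, nin, nout)
    return ToyKANLayer(centers, coeffs, σ)
end

# Forward pass: output node j just sums the edge functions φ_{j,i}(x_i) over inputs i.
function (l::ToyKANLayer)(x::AbstractVector)
    nbasis, nin, nout = size(l.coeffs)
    basis = [exp(-((x[i] - l.centers[k]) / l.σ)^2) for k in 1:nbasis, i in 1:nin]
    return [sum(l.coeffs[k, i, j] * basis[k, i] for k in 1:nbasis, i in 1:nin)
            for j in 1:nout]
end

layer = ToyKANLayer(2, 3)
layer([0.3, -0.7])   # 3-element output; training would fit coeffs with Flux/Zygote
```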

I didn't state why this would work: it turns out that KANs are based on the Kolmogorov-Arnold representation theorem, which is a convenient way to represent continuous real-valued multivariate functions as sums of univariate functions over the domain $[0,1]$. I did not know about this!
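For my own reference, the theorem says that any continuous $f:[0,1]^n \to \mathbb{R}$ can be written using only univariate continuous functions and addition:

$$ f(x_1,\dots,x_n) = \sum_{q=0}^{2n} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right) $$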

This is my best understanding of the KAN paper, but I've barely scratched the surface of the 48 pages. The other thing is that the function composition in eq. $\eqref{eq:KA}$ is represented in KANs as matrix operations, i.e., $\text{KAN}(\mathbf{x})= (\Phi_{N} \;\circ \Phi_{N-1} \; \circ \;\dots \;\circ \Phi_1)(\mathbf{x})$, so I think it's just a matrix of functions (see eq. 2.6 in [1]). I need to understand this better.
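Writing out what I think eq. 2.6 in [1] is saying for a single layer: each $\Phi_l$ is a matrix of univariate functions rather than a matrix of numbers, applied entrywise to the layer input and then summed,

$$ \mathbf{x}_{l+1} = \Phi_l(\mathbf{x}_l), \qquad \left(\Phi_l(\mathbf{x}_l)\right)_j = \sum_{i=1}^{n_l} \phi_{l,j,i}\big( (\mathbf{x}_l)_i \big) $$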

If I am getting this wrong above, please drop a comment. The graphical abstract covers everything I tried to convey above better than I can:

Abstract graphic for KANs from [1].

The thing that interests me is that KANs actually seem more intuitive: the network learns the representative non-linear functions that capture characteristics of the data. I guess MLPs do this as well, just less effectively? Well, at least based on the paper, which shows a much better scaling law for KANs compared to MLPs.

I don't think I'll have much time to focus more on this, but I am very curious about the use of KANs over gated MLPs in things like Graph Neural Networks. The authors of the preprint put together a very nice and usable Python package, pykan.

Update 16 May, 2024

The author of ref. [1] posted a review video of the paper.

AlphaFold 3

Now onto the field of computational biochemistry and molecular biology. Google's DeepMind just published their research results for AlphaFold 3 [2]. Basically, it appears they have the ability to develop and discover new drugs using computational means, as well as address other life-science challenges. Google is probably sitting on a very lucrative resource, so kudos to them for investing in it.

The thing that's a little bit off is how all these big tech companies are doing research and publishing in science journals, yet no one has access to or knowledge of exactly how these models work and are trained. It's not bad that they want to keep the information internal, but then they probably shouldn't publish. I should make a blog post on this topic, about how big tech scientists are now in the business of publishing papers; just a new kind of era, I guess.

I should mention that they did make available a compute server where non-commercial users can make queries with AlphaFold 3.

Footnotes


  1. I'll use $i$ to represent the layer edges, and therefore the inputs $\mathbf{x}$ will be indexed $i-1$, but this isn't standard notation. 

  2. This is totally different from how it's written in ref. [1], but it helps me understand; at least I think 😏. 


References

[1] Z. Liu, Y. Wang, S. Vaidya, F. Ruehle, J. Halverson, M. Soljačić, T.Y. Hou, M. Tegmark, KAN: Kolmogorov-Arnold Networks, (2024). arXiv.

[2] Google DeepMind and Isomorphic Labs, Accurate structure prediction of biomolecular interactions with AlphaFold 3, Nature (2024) 1–3. https://doi.org/10.1038/s41586-024-07487-w.


