ChemistryFeaturization Usage Example: Featurizing AtomGraphs

This tutorial demonstrates ChemistryFeaturization.jl on a case where the interface is already implemented. The structure representation is AtomGraph, provided by the AtomGraphs package, and the featurization is GraphNodeFeaturization, from AtomicGraphNets. We will choose some ElementFeatureDescriptors with which to featurize each node of an AtomGraph object and step through the process.

using ChemistryFeaturization, AtomGraphs, AtomicGraphNets, PlutoUI
#! format: off
TableOfContents()

Creating AtomGraph objects

AtomGraph objects are simply adjacency matrices with elemental symbols labeling each node. We can create one "from scratch" by manually specifying an adjacency matrix, like so...

adj_mat = Float32.([0 1 1; 1 0 1; 1 1 0])
3×3 Matrix{Float32}:
 0.0  1.0  1.0
 1.0  0.0  1.0
 1.0  1.0  0.0
triangle_C = AtomGraph(adj_mat, ["C", "C", "C"])
AtomGraphs.AtomGraph{SimpleWeightedGraphs.SimpleWeightedGraph{Int64, Float32}}  with 3 nodes, 3 edges
	atoms: ["C", "C", "C"]
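The following sketch makes "adjacency matrix plus element labels" concrete with a minimal stand-alone container. ToyAtomGraph and num_edges are hypothetical names used only for illustration. AtomGraphs.jl's own AtomGraph wraps a SimpleWeightedGraph instead, as the printed type above shows.

```julia
# Hypothetical stand-in for an atom graph: a symmetric weighted adjacency
# matrix plus one element symbol per node. Not the AtomGraphs.jl implementation.
struct ToyAtomGraph
    adjacency::Matrix{Float32}
    elements::Vector{String}
    function ToyAtomGraph(adjacency::AbstractMatrix, syms::AbstractVector{<:AbstractString})
        n = size(adjacency, 1)
        size(adjacency, 2) == n || throw(ArgumentError("adjacency matrix must be square"))
        adjacency == transpose(adjacency) || throw(ArgumentError("adjacency matrix must be symmetric"))
        length(syms) == n || throw(ArgumentError("need exactly one element symbol per node"))
        new(Float32.(adjacency), String.(syms))
    end
end

# Undirected edge count: nonzero entries strictly above the diagonal.
function num_edges(g::ToyAtomGraph)
    n = length(g.elements)
    return count(!iszero(g.adjacency[i, j]) for i in 1:n for j in (i + 1):n)
end

toy = ToyAtomGraph([0 1 1; 1 0 1; 1 1 0], ["C", "C", "C"])
```

Like triangle_C, this toy graph has 3 nodes and 3 edges.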

The built-in visualize function in AtomGraphs.jl doesn't work in Pluto notebooks, so we'll create a slightly modified one that will display inline properly and visualize the graph.

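begin
	using GraphPlot, Graphs

	# Pluto-friendly variant of AtomGraphs' visualize: build an unweighted
	# SimpleGraph from the adjacency weights and plot it with GraphPlot,
	# coloring and labeling nodes by element, with edge widths taken from
	# AtomGraphs.graph_edgewidths.
	function plutoviz(ag::AtomGraph)
		sg = SimpleGraph(ag.graph.weights)
		gplot(
			sg,
			nodefillc = AtomGraphs.graph_colors(elements(ag)),
			nodelabel = elements(ag),
			edgelinewidth = AtomGraphs.graph_edgewidths(ag),
		)
	end
end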
plutoviz (generic function with 1 method)
plutoviz(triangle_C)
[Plot: triangle_C drawn as three connected nodes, each labeled C]

Of course, in practice we'll more likely be reading in structures and building graphs from files such as .cif, .xyz, etc. Here, we'll read in the structure of WS₂, downloaded from the Materials Project:

WS2 = AtomGraph("../files/mp-224.cif")
AtomGraphs.AtomGraph{Xtals.Crystal} ../files/mp-224 with 6 nodes, 6 edges
	atoms: ["W", "W", "S", "S", "S", "S"]

The graph is automatically assigned an id based on the name of the file it was read from, but you can pass your own value to the constructor to override this and name it something else.

We can visualize it just like we did before:

plutoviz(WS2)
[Plot: WS2 graph with nodes labeled W, W, S, S, S, S]

Interesting! It seems this graph has two disconnected components. This isn't too surprising if we look at the 3D structure of this compound:

begin
	using ImageShow, Images
	load("../files/mp-224.png")
end

It's a two-dimensional material with two formula units per unit cell! Another way to see the disconnectedness of the graph is to index into the adjacency matrix in a particularly illustrative order:

WS2.graph[[1,4,6,2,3,5]].weights
6×6 SparseArrays.SparseMatrixCSC{Float64, Int64} with 12 stored entries:
  ⋅   1.0      1.0       ⋅    ⋅        ⋅ 
 1.0   ⋅       0.19762   ⋅    ⋅        ⋅ 
 1.0  0.19762   ⋅        ⋅    ⋅        ⋅ 
  ⋅    ⋅        ⋅        ⋅   1.0      1.0
  ⋅    ⋅        ⋅       1.0   ⋅       0.19762
  ⋅    ⋅        ⋅       1.0  0.19762   ⋅ 
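To check disconnectedness programmatically rather than by eye, you can label connected components with a breadth-first search over the adjacency matrix. The sketch below is plain Julia. component_labels is a hypothetical helper, and the test matrix is the reordered WS2 matrix printed above.

```julia
# Sketch: label the connected components of a graph from its (weighted)
# adjacency matrix via breadth-first search; any nonzero entry is an edge.
function component_labels(adj::AbstractMatrix)
    n = size(adj, 1)
    labels = zeros(Int, n)              # 0 means "not visited yet"
    ncomp = 0
    for start in 1:n
        labels[start] == 0 || continue  # already assigned to a component
        ncomp += 1
        labels[start] = ncomp
        queue = [start]
        while !isempty(queue)
            i = popfirst!(queue)
            for j in 1:n
                if !iszero(adj[i, j]) && labels[j] == 0
                    labels[j] = ncomp
                    push!(queue, j)
                end
            end
        end
    end
    return labels
end

# Reordered WS2 weights from above (rows/columns in the order 1, 4, 6, 2, 3, 5)
ws2_adj = [0 1       1       0 0       0;
           1 0       0.19762 0 0       0;
           1 0.19762 0       0 0       0;
           0 0       0       0 1       1;
           0 0       0       1 0       0.19762;
           0 0       0       1 0.19762 0]
ws2_labels = component_labels(ws2_adj)
```

Graphs.jl, already loaded above for plotting, also provides connected_components. Something like connected_components(SimpleGraph(WS2.graph.weights)) should give the same grouping without a hand-written search.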

We also have options in how the graph is constructed. The default scheme is based on the original cgcnn.py implementation, which essentially sets a maximum neighbor distance and a maximum number of neighbors per atom. Unlike that implementation, however, we construct weighted graphs: edge weights decay with interatomic separation according to a user-specifiable function, which defaults to inverse-square.
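As a rough illustration of this default scheme, here is a plain-Julia sketch. It keeps at most max_num_nbr neighbors within radius of each atom and weights each edge by an inverse-square decay.

It is not the AtomGraphs implementation. The following choices are all assumptions:
- the name cutoff_graph;
- its default parameter values (picked to resemble commonly used cgcnn settings);
- the symmetrization;
- rescaling so the strongest edge has weight 1 (chosen only because the printed weights above top out at 1.0).

Periodic images, which a real crystal needs, are ignored.

```julia
using LinearAlgebra: norm

# Hypothetical cgcnn-style neighbor scheme with weighted edges. Each atom keeps
# at most `max_num_nbr` neighbors within `radius`; each edge gets weight
# `decay(distance)`. Periodic images are ignored (non-periodic coordinates only).
function cutoff_graph(coords::AbstractMatrix{<:Real}; radius = 8.0, max_num_nbr = 12,
                      decay = d -> 1 / d^2)
    n = size(coords, 2)                        # one column per atom
    W = zeros(n, n)
    for i in 1:n
        # (distance, index) pairs for every other atom inside the cutoff
        nbrs = [(norm(coords[:, i] .- coords[:, j]), j) for j in 1:n if j != i]
        filter!(t -> t[1] <= radius, nbrs)
        sort!(nbrs)                            # nearest first
        for (d, j) in first(nbrs, max_num_nbr)
            W[i, j] = W[j, i] = max(W[i, j], decay(d))   # keep the graph symmetric
        end
    end
    m = maximum(W)
    return m > 0 ? W ./ m : W                  # assumed: strongest edge rescaled to 1
end

coords = [0.0 1.0 2.0;
          0.0 0.0 0.0;
          0.0 0.0 0.0]                         # three atoms on a line, 1 Å apart
W = cutoff_graph(coords)
```

With these coordinates, the two nearest-neighbor edges get weight 1 and the 2 Å edge gets 1/2² = 0.25. With max_num_nbr = 1, the 2 Å edge disappears entirely.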

An arguably more physical way to construct neighbor lists and graphs is a Voronoi partition of the atomic coordinates. In this scheme, an atom's neighbors are all atoms whose Voronoi polyhedra share a face with its own, and the edge weights can be determined from the areas of those shared faces. Let's try that with our WS₂ structure...

WS2_v = AtomGraph(joinpath("..", "files", "mp-224.cif"), use_voronoi=true)
AtomGraphs.AtomGraph{PyCall.PyObject} ../files/mp-224 with 6 nodes, 14 edges
	atoms: ["W", "W", "S", "S", "S", "S"]
WS2_v.graph[[1,4,6,2,3,5]].weights
6×6 SparseArrays.SparseMatrixCSC{Float64, Int64} with 22 stored entries:
 0.371678  0.970448   0.970448    ⋅         ⋅          ⋅ 
 0.970448  1.0        0.0231855   ⋅         ⋅         0.31894
 0.970448  0.0231855  1.0         ⋅        0.31894     ⋅ 
  ⋅         ⋅          ⋅         0.371678  0.970448   0.970448
  ⋅         ⋅         0.31894    0.970448  1.0        0.0231855
  ⋅        0.31894     ⋅         0.970448  0.0231855  1.0

Batch Processing

One final note for this section: like any Julia function, the AtomGraph constructor can be broadcast with dot syntax! So if you have a directory full of structure files (say, strucs/), you can get a list of AtomGraph objects with:

ags = AtomGraph.(readdir("strucs/", join=true))
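In practice a directory may also contain files that are not structures (READMEs, hidden files), and broadcasting over all of readdir's output would try to parse them too. Below is a minimal sketch of filtering by extension first. It assumes .cif and .xyz are the formats you want to keep. load_structure is a hypothetical placeholder standing in for AtomGraph, so the example runs on its own.

```julia
# Assumed set of structure-file extensions to keep (adjust to your data).
const STRUCTURE_EXTS = (".cif", ".xyz")

# Full paths of files in `dir` whose extension is in STRUCTURE_EXTS.
structure_files(dir) = filter(f -> lowercase(splitext(f)[2]) in STRUCTURE_EXTS,
                              readdir(dir; join = true))

# Hypothetical placeholder for AtomGraph(path), so this sketch is self-contained.
load_structure(path) = basename(path)

# With AtomGraphs loaded, the same pattern would be:
# ags = AtomGraph.(structure_files("strucs/"))
```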

Building and Encoding Feature Descriptors

What types of features of our structure do we want to encode in our graph? Let's keep things simple for now and consider features that depend only on the elemental identity of a given atom (a node in our graph). The package includes a good deal of built-in elemental data, and you can also provide your own for features we haven't included!

We can easily construct these for built-in features...

using ChemistryFeaturization.ElementFeature

Categorical features

Let's start with a categorical feature, that is, one that takes on a finite set of discrete values. One example of this is which block (s, p, d, or f) in the periodic table an element resides in.

block = ElementFeatureDescriptor("Block") # categorical feature denoting s-, p-, d-, or f-block elements
ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor Block

We can get the values of the feature for a given structure by "calling" it directly...

block(triangle_C)
3-element Vector{InlineStrings.String1}:
 "p"
 "p"
 "p"

...or by using the get_value function.

get_value(block, WS2)
6-element Vector{InlineStrings.String1}:
 "d"
 "d"
 "p"
 "p"
 "p"
 "p"

Of course, vectors of single characters are not going to be all that useful to feed into a machine learning model. To "translate" these human-readable values, we need to encode them. For this, we'll use a codec object. Codec is short for "encoder-decoder", reflecting a key design principle of ChemistryFeaturization that we should always know what information we've encoded and be able to invert that encoding process.

A common method of encoding categorical-valued features is so-called "one-hot" encoding. In ChemistryFeaturization, this is implemented via the OneHotOneCold codec. We can retrieve a "sensible" one for our block feature using the default_codec function...

block_codec = default_codec(block)
OneHotOneCold(true, InlineStrings.String1["d", "f", "p", "s"])

So what is this OneHotOneCold thing? It may be easiest to see by using it. We can encode single values...

d_encoded = encode("d", block_codec)
4-element Vector{Float64}:
 1.0
 0.0
 0.0
 0.0
s_encoded = encode("s", block_codec)
4-element Vector{Float64}:
 0.0
 0.0
 0.0
 1.0

...or even whole structures...

WS2_block_encoded = encode(WS2, block)
6×4 Matrix{Float64}:
 1.0  0.0  0.0  0.0
 1.0  0.0  0.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  1.0  0.0
 0.0  0.0  1.0  0.0

The output for a single value is a vector of zeros with a single 1 in the "slot" corresponding to the value. The slots are specified by the bins field of the codec, in this case ["d", "f", "p", "s"]. For a structure like our WS2 AtomGraph, these vectors are stacked into a matrix whose first index is the index of the atom and whose second indexes into the feature vector. We could equivalently have called encode(WS2, block, block_codec) above. Passing the codec isn't necessary, though, because encode falls back to default_codec when no codec is given.
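The categorical scheme just described can be sketched in a few lines of plain Julia. onehot_encode, onehot_matrix, and decode_onehot are hypothetical helpers written for illustration, not the OneHotOneCold implementation. They do reproduce the WS2 block encoding shown above.

```julia
# Hypothetical sketch of categorical one-hot encoding, in the spirit of
# OneHotOneCold(true, bins); not the package's implementation.
function onehot_encode(value, bins::AbstractVector)
    idx = findfirst(==(value), bins)
    idx === nothing && throw(ArgumentError("value $value is not one of $bins"))
    v = zeros(length(bins))
    v[idx] = 1.0                 # a single 1 in the slot matching `value`
    return v
end

# Encode a whole structure: one row per atom, one column per category.
function onehot_matrix(vals, bins)
    M = zeros(length(vals), length(bins))
    for (i, v) in enumerate(vals)
        M[i, :] = onehot_encode(v, bins)
    end
    return M
end

# Decoding: return the category whose slot holds the 1.
decode_onehot(v::AbstractVector, bins) = bins[argmax(v)]

block_bins = ["d", "f", "p", "s"]
ws2_blocks = ["d", "d", "p", "p", "p", "p"]
ws2_block_matrix = onehot_matrix(ws2_blocks, block_bins)
```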

We can also always decode what we encode. As with encode, ChemistryFeaturization calls default_codec internally, so if you're using the defaults, you can pass either the codec or the feature descriptor:

decode(s_encoded, block_codec)
"s"
decode(WS2_block_encoded, block)
6-element Vector{Union{Missing, InlineStrings.String1}}:
 "d"
 "d"
 "p"
 "p"
 "p"
 "p"

So what's that other field in the codec that has a value of true? Read on...

Continuous-valued features

Other features are better described as taking values on a continuum. One example is atomic mass, another built-in ElementFeatureDescriptor:

amass = ElementFeatureDescriptor("Atomic mass") # continuous-valued feature
ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor Atomic mass
amass_codec = default_codec(amass)
OneHotOneCold(false, [1.00794, 1.757605251627154, 3.0648413799902294, 5.344347188200727, 9.319257777745117, 16.250547067714365, 28.33705068558546, 49.413009802775825, 86.16442003300078, 150.25005174257413, 261.9999999999999])

Note that this time that first flag is false. The flag describes whether the codec is describing a categorical or continuous-valued feature, and this influences how the bins are interpreted. For categorical features, each value in bins corresponds to a possible value of the feature. For continuous features, bins represents N+1 edges of N bins into which we've divided the possible values of this feature.
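For the continuous case, encoding means finding which of the N bins (between N+1 sorted edges) a value falls into. This sketch makes two assumptions: a value equal to the top edge belongs to the last bin, and out-of-range values are errors. The package's actual edge handling may differ. The edges are copied from the default atomic-mass codec printed above, and the helper names are hypothetical.

```julia
# Hypothetical sketch of continuous-valued one-hot encoding: `edges` holds
# N+1 sorted bin edges, so the encoded vector has N entries.
function onehot_continuous(x::Real, edges::AbstractVector{<:Real})
    edges[1] <= x <= edges[end] ||
        throw(DomainError(x, "outside [$(edges[1]), $(edges[end])]"))
    nbins = length(edges) - 1
    idx = min(searchsortedlast(edges, x), nbins)   # x == edges[end] goes in the last bin
    v = zeros(nbins)
    v[idx] = 1.0
    return v
end

# Decoding can only recover the bin, returned as its (lower, upper) edges.
function decode_continuous(v::AbstractVector, edges)
    i = argmax(v)
    return (edges[i], edges[i + 1])
end

# Atomic-mass edges as printed by default_codec(amass)
amass_edges = [1.00794, 1.757605251627154, 3.0648413799902294, 5.344347188200727,
               9.319257777745117, 16.250547067714365, 28.33705068558546,
               49.413009802775825, 86.16442003300078, 150.25005174257413,
               261.9999999999999]
carbon = onehot_continuous(12.0107, amass_edges)
```

With these edges, carbon (12.0107) lands in bin 5, the same slot as in the triangle_C encoding. Decoding returns the bin edges rather than the mass itself.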

In addition, for atomic mass the bins are logarithmically spaced by default. For more on how to tune these defaults, check out the src/codecs/onehotonecold_utils.jl source file.
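Logarithmic spacing means consecutive edges share a constant ratio. The sketch below assumes the edges run geometrically from the lightest to the heaviest tabulated mass (1.00794 to 262). Under that assumption it reproduces the default edges printed above, but whether the package computes them exactly this way is not confirmed. logspaced_edges and linspaced_edges are hypothetical names.

```julia
# N bins => N+1 edges. Log spacing: evenly spaced in log(x), so every bin
# has the same upper/lower ratio. Linear spacing: evenly spaced in x.
logspaced_edges(lo, hi, nbins) = exp.(range(log(lo), log(hi); length = nbins + 1))
linspaced_edges(lo, hi, nbins) = collect(range(lo, hi; length = nbins + 1))

edges10 = logspaced_edges(1.00794, 262.0, 10)
```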

Note a further consequence of this difference in the interpretation of bins for categorical vs. continuous features:

length.([block_codec.bins, amass_codec.bins])
2-element Vector{Int64}:
  4
 11
output_shape.([block_codec, amass_codec])
2-element Vector{Int64}:
  4
 10

(Beware off-by-one errors, all ye who enter here.)

Let's try encoding/decoding some values!

amass(triangle_C)
3-element Vector{Float64}:
 12.0107
 12.0107
 12.0107
triangle_C_amass_encoded = encode(triangle_C, amass)
3×10 Matrix{Float64}:
 0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0
decode(triangle_C_amass_encoded, amass)
3-element Vector{Union{Missing, Tuple{Float64, Float64}}}:
 (9.319257777745117, 16.250547067714365)
 (9.319257777745117, 16.250547067714365)
 (9.319257777745117, 16.250547067714365)

Oooh, interesting, what's happened here?

It turns out that one-hot encoding loses some information for continuous-valued features. Instead of getting back carbon's atomic mass of 12.01, we get back tuples indicating the edges of the bin it fell into. If you want to encode with higher resolution, you can build a codec with more bins!

amass_codec_hires = OneHotOneCold(false, get_bins(amass_codec.bins, nbins=16, logspaced=true))
OneHotOneCold(false, [1.00794, 1.4268024409661881, 2.019728560774523, 2.859052761674519, 4.04716893784186, 5.729022084167015, 8.109790953865007, 11.47991897206193, 16.250547067714365, 23.003671074915946, 32.563142687931204, 46.09517577700111, 65.25061939737554, 92.36635417855338, 130.75038157539873, 185.0853856271599, 261.9999999999999])
decode(encode(triangle_C, amass, amass_codec_hires), amass_codec_hires)
3-element Vector{Union{Missing, Tuple{Float64, Float64}}}:
 (11.47991897206193, 16.250547067714365)
 (11.47991897206193, 16.250547067714365)
 (11.47991897206193, 16.250547067714365)
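Log-spaced bins all share the same ratio r_N between their upper and lower edges, so the resolution of a decoded value can be quantified directly. Here it is computed with the endpoints of the default atomic-mass edges:

```latex
r_N = \left(\frac{x_{\max}}{x_{\min}}\right)^{1/N}, \qquad
r_{10} = \left(\frac{262}{1.00794}\right)^{1/10} \approx 1.744, \qquad
r_{16} = \left(\frac{262}{1.00794}\right)^{1/16} \approx 1.416
```

This matches the decoded intervals above: 16.2505/9.3193 ≈ 1.744 and 16.2505/11.4799 ≈ 1.416. Going from 10 to 16 bins therefore narrows carbon's decoded interval from a factor of about 1.74 to about 1.42.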

Custom features

But suppose you have another feature that's not included. You can easily provide a lookup table (or even an entire custom encoding function!) yourself, like so...

begin
	using DataFrames
	lookup_table = DataFrame(["C" 42; "As" 0], [:Symbol, :MeaningOfLife]); # make a custom lookup table for another feature
	meaning = ElementFeatureDescriptor("MeaningOfLife", lookup_table)
end
ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor MeaningOfLife

Once you've done this, you can use this feature the same way you would use the built-in ones. Just make sure you're okay with the default decisions the package makes about whether to treat the feature as categorical vs. continuous and space the bins linearly or logarithmically (again, see src/codecs/onehotonecold_utils.jl for details) and tweak the flags if you're not.
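Stripped to its essentials, a lookup-table feature is just a map from element symbol to value. The plain-Julia sketch below uses a Dict in place of the DataFrame. The helpers lookup_feature and lookup_categories are hypothetical, not ChemistryFeaturization functions. Elements absent from the table come back as missing, which is the kind of gap that stops a feature from encoding a structure.

```julia
# Hypothetical stand-in for the MeaningOfLife lookup table, using a Dict
# (element symbol => value) instead of a DataFrame.
meaning_table = Dict("C" => 42, "As" => 0)

# One value per atom; `missing` marks elements the table does not cover.
lookup_feature(table::AbstractDict, syms) = [get(table, s, missing) for s in syms]

# Sorted distinct values: candidate one-hot slots if the feature is treated as categorical.
lookup_categories(table::AbstractDict) = sort(unique(values(table)))

carbon_vals = lookup_feature(meaning_table, ["C", "C", "C"])
ws2_vals = lookup_feature(meaning_table, ["W", "W", "S", "S", "S", "S"])
```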

Building a featurization

Next, we can combine these feature descriptors into a featurization object. A featurization encodes multiple features of a structure at once and combines the encoded features in a form suitable for feeding into a model. In the case of GraphNodeFeaturization, we construct a vector for each node in an AtomGraph by concatenating its encoded features. We then stack these vectors to form a feature matrix that we could feed into an AtomicGraphNets model.

This featurization has a convenience constructor that builds the ElementFeatureDescriptors if you just pass in feature names. Since one of our features uses a custom lookup table, though, we construct it by passing the feature descriptors directly:

fzn = GraphNodeFeaturization([block, amass, meaning])
GraphNodeFeaturization encoding 3 features:
	ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor Block
	ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor Atomic mass
	ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor MeaningOfLife

(As a quick side note, the featurization is basically just a bundle of feature descriptors and their associated codecs, which we can of course inspect:)

fzn.codecs
3-element Vector{OneHotOneCold}:
 OneHotOneCold(true, InlineStrings.String1["d", "f", "p", "s"])
 OneHotOneCold(false, [1.00794, 1.757605251627154, 3.0648413799902294, 5.344347188200727, 9.319257777745117, 16.250547067714365, 28.33705068558546, 49.413009802775825, 86.16442003300078, 150.25005174257413, 261.9999999999999])
 OneHotOneCold(true, Any[0, 42])

Featurizing structures

We've already seen how we can retrieve and encode values of individual features for atomic structures. We can of course do this for featurizations too:

encode(triangle_C, fzn)
3×16 Matrix{Float64}:
 0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0
 0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0
 0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0
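Suppose the encoded blocks are simply concatenated in the order the features were listed (block, atomic mass, MeaningOfLife). Then the 3×16 matrix above can be reproduced by hand. There are 4 + 10 + 2 = 16 columns, with 1s at:
- column 3, for block "p";
- column 4 + 5 = 9, for mass bin 5, as in the earlier atomic-mass encoding;
- column 14 + 2 = 16, for MeaningOfLife = 42, the second of its two slots.

The helper onehot_rows is hypothetical.

```julia
# Hypothetical helper: a one-hot block with one row per atom, where
# `indices[row]` is the slot holding the 1 for that atom.
function onehot_rows(indices::AbstractVector{Int}, width::Int)
    M = zeros(length(indices), width)
    for (row, col) in enumerate(indices)
        M[row, col] = 1.0
    end
    return M
end

# Carbon triangle: slot 3 of 4 for block "p", bin 5 of 10 for atomic mass,
# slot 2 of 2 for MeaningOfLife = 42 (slots [0, 42]).
block_part   = onehot_rows([3, 3, 3], 4)
amass_part   = onehot_rows([5, 5, 5], 10)
meaning_part = onehot_rows([2, 2, 2], 2)

# Node feature matrix: per-feature encodings side by side, one row per atom.
node_features = hcat(block_part, amass_part, meaning_part)
```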

If we want to attach the encoded features to the graph, we can use the featurize function, which returns a FeaturizedAtoms object. This is the recommended approach for preprocessing many structures. A serialized FeaturizedAtoms object stores the structure, the featurization, and the encoded features, so you know what you encoded and how, and you can always decode it.

featurize(WS2, fzn)
AssertionError("Feature MeaningOfLife cannot encode some element(s) in this structure!")

Oops! What happened here?!?

...Our custom feature can't actually encode values for tungsten or sulfur. How could we have known this ahead of time? There's a function for that!

encodable_elements(fzn)
2-element Vector{InlineStrings.String3}:
 "C"
 "As"

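The output above is consistent with encodable_elements returning the intersection of the element sets covered by each feature. That is an assumption on our part, sketched below with toy element lists; featurization_elements and can_featurize are hypothetical helpers. Checking a structure's elements against this set before calling featurize avoids the AssertionError we just hit.

```julia
# Hypothetical: a featurization can only handle elements that *every* one of
# its features can encode, i.e. the intersection of the per-feature sets.
featurization_elements(per_feature) = reduce(intersect, per_feature)

# True only if every element in the structure is encodable.
can_featurize(syms, allowed) = all(in(allowed), syms)

per_feature = [
    ["H", "C", "S", "As", "W"],   # toy stand-in for a built-in feature's ~100 elements
    ["H", "C", "S", "As", "W"],   # likewise
    ["C", "As"],                  # the MeaningOfLife lookup table
]
allowed = featurization_elements(per_feature)
```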
Let's try a different featurization; we'll replace MeaningOfLife with electronegativity.

new_fzn = GraphNodeFeaturization(["Block", "Atomic mass", "X"])
GraphNodeFeaturization encoding 3 features:
	ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor Block
	ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor Atomic mass
	ChemistryFeaturization.ElementFeature.ElementFeatureDescriptor X
encodable_elements(new_fzn)
100-element Vector{InlineStrings.String3}:
 "H"
 "Li"
 "Be"
 "B"
 "C"
 "N"
 "O"
 ⋮
 "Cf"
 "Es"
 "Fm"
 "Md"
 "No"
 "Lr"

That looks better!

featurized_WS2 = featurize(WS2, new_fzn)
FeaturizedAtoms:
	Atoms: AtomGraphs.AtomGraph{Xtals.Crystal} ../files/mp-224 with 6 nodes, 6 edges
	Featurization: GraphNodeFeaturization encoding 3 features
propertynames(featurized_WS2)
(:atoms, :featurization, :encoded_features)
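Since the recommended workflow is to store featurized structures, here is a hedged sketch of a round trip through Julia's Serialization standard library. It uses a hypothetical ToyFeaturized struct with the same three field names as FeaturizedAtoms, so it runs without the package. The same serialize/deserialize calls should apply to featurized_WS2, assuming all of its fields are plain Julia data. Keep in mind that files written this way may not be readable after the struct definitions (or the Julia version) change.

```julia
using Serialization

# Hypothetical stand-in with the same field names as FeaturizedAtoms.
struct ToyFeaturized
    atoms::Vector{String}
    featurization::Vector{String}
    encoded_features::Matrix{Float64}
end

fa = ToyFeaturized(["C", "C", "C"], ["Block", "Atomic mass"], rand(3, 14))

path = joinpath(mktempdir(), "featurized.jls")
serialize(path, fa)          # write the whole object to disk
fa2 = deserialize(path)      # read it back (same session, same type definition)
```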

As promised, we can still decode!

decode(featurized_WS2)
Dict{Integer, Dict{String, Any}} with 6 entries:
  5 => Dict("X"=>(2.34, 2.668), "Block"=>"p", "Atomic mass"=>(28.3371, 49.413))
  4 => Dict("X"=>(2.34, 2.668), "Block"=>"p", "Atomic mass"=>(28.3371, 49.413))
  6 => Dict("X"=>(2.34, 2.668), "Block"=>"p", "Atomic mass"=>(28.3371, 49.413))
  2 => Dict("X"=>(2.34, 2.668), "Block"=>"d", "Atomic mass"=>(150.25, 262.0))
  3 => Dict("X"=>(2.34, 2.668), "Block"=>"p", "Atomic mass"=>(28.3371, 49.413))
  1 => Dict("X"=>(2.34, 2.668), "Block"=>"d", "Atomic mass"=>(150.25, 262.0))
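Decoding a featurized structure amounts to two steps:
1. slice each row of the feature matrix back into per-feature column blocks;
2. decode each block with its codec.

The sketch below does this for a toy two-feature matrix. decode_rows, the bin edges, and the decoders are illustrative assumptions. The result has the same Dict-of-Dicts shape as the output above.

```julia
# Hypothetical sketch: split each row into per-feature column blocks of the
# given widths and decode each block, giving Dict(atom => Dict(name => value)).
function decode_rows(enc::AbstractMatrix, names, widths, decoders)
    @assert sum(widths) == size(enc, 2)
    out = Dict{Int, Dict{String, Any}}()
    for i in 1:size(enc, 1)
        col = 0
        row = Dict{String, Any}()
        for (name, w, dec) in zip(names, widths, decoders)
            row[name] = dec(enc[i, col+1:col+w])
            col += w
        end
        out[i] = row
    end
    return out
end

blk_bins = ["d", "f", "p", "s"]
mass_edges = [1.0, 10.0, 100.0, 1000.0]                        # toy 3-bin edges
decoders = (v -> blk_bins[argmax(v)],                          # categorical: slot -> value
            v -> (i = argmax(v); (mass_edges[i], mass_edges[i+1])))  # continuous: bin edges

enc = [1.0 0 0 0  0 0 1;    # atom 1: d-block, mass in (100, 1000)
       0   0 1 0  0 1 0]    # atom 2: p-block, mass in (10, 100)
decoded = decode_rows(enc, ("Block", "Atomic mass"), (4, 3), decoders)
```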