
Chainer graph #398

Merged 54 commits on Sep 25, 2019
bd0a049
add dataset class which specifies batch method.
Sep 3, 2019
b3ae8ef
add files
Sep 3, 2019
58036a6
adapt sparse relgcn to dataset class
knshnb Sep 4, 2019
b47448b
Fix variable name
knshnb Sep 4, 2019
ddd4c45
Support RelGCN in dataset class
knshnb Sep 4, 2019
a787329
Refactor
knshnb Sep 4, 2019
3ec27db
Support GIN in dataset class
knshnb Sep 4, 2019
1a24f0c
arrange directory structure
knshnb Sep 4, 2019
21dc6a9
change RelGCNUpdate interface
knshnb Sep 4, 2019
01e5484
support GINSparse
knshnb Sep 4, 2019
1dd965c
concat node num option
knshnb Sep 5, 2019
121ff95
padding converter for sparse pattern
knshnb Sep 5, 2019
5995f2a
add cache directory to .gitignore
knshnb Sep 5, 2019
8ab6b5b
implement scatter general readout
knshnb Sep 5, 2019
9639804
store n_nodes in graph data
knshnb Sep 10, 2019
941645f
set weight tying option false by default in GIN
knshnb Sep 9, 2019
f81489f
don't use relu in last layer of GIN
knshnb Sep 9, 2019
1c43011
fix constructor of SparseGraphData
knshnb Sep 10, 2019
adbe443
support training on cora dataset
knshnb Sep 10, 2019
72553d5
support training on citeseer dataset
knshnb Sep 10, 2019
310d334
support training on reddit dataset
knshnb Sep 10, 2019
b2f322a
implement NodeClassifier
knshnb Sep 11, 2019
96e8cbc
fix device
knshnb Sep 11, 2019
0af27f5
fix on GPU
knshnb Sep 11, 2019
105d8d6
fix GIN
knshnb Sep 11, 2019
f5fe033
add dropout as command-line argument
knshnb Sep 12, 2019
74f5f96
add first_mlp to GINSparse
knshnb Sep 12, 2019
892e0fc
add reddit data to gitignore
knshnb Sep 12, 2019
85aa52a
support node classification on GIN (dense)
knshnb Sep 12, 2019
7be3f4d
change command-line argument default values
knshnb Sep 12, 2019
99844c1
Fix reddit input feature type to float32
knshnb Sep 12, 2019
4226fe2
efficient networkx preprocess
knshnb Sep 13, 2019
e2286ca
delete unnecessary
knshnb Sep 13, 2019
da33109
fix BaseSparseNetworkPreprocessor
knshnb Sep 13, 2019
1510f91
support coo matrix on gin
knshnb Sep 13, 2019
efe666f
print train label num
knshnb Sep 13, 2019
f0127cf
efficient read for reddit coo data
knshnb Sep 13, 2019
f064792
fix on GPU
knshnb Sep 13, 2019
0a9166f
fix variable name
knshnb Sep 17, 2019
2e84983
add readme for train_network_graph
knshnb Sep 19, 2019
4a51521
add links of datasets to readme
knshnb Sep 19, 2019
5ec8919
use NumpyTupleDataset chemical padding pattern
knshnb Sep 19, 2019
86d88f2
use NumpyTupleDataset in GGNN and RelGCN
knshnb Sep 19, 2019
70aa90f
add docstring
knshnb Sep 19, 2019
f1a8bf1
change network dataset directory under examples/
knshnb Sep 19, 2019
52470f8
save NumpyTupleDataset
knshnb Sep 20, 2019
628afa9
change file name
knshnb Sep 20, 2019
970401c
use mlp inside models
knshnb Sep 20, 2019
337cbfd
wrap padding pattern model by another Chain
knshnb Sep 20, 2019
56c0fd7
Merge remote-tracking branch 'upstream/master' into chainer-graph
knshnb Sep 20, 2019
86dab65
add reference to network datasets
knshnb Sep 20, 2019
dff0925
delete unnecessary modification
knshnb Sep 20, 2019
d6be7f6
fix weight_tying option
knshnb Sep 20, 2019
56218b1
move padding model wrapper under example
knshnb Sep 20, 2019
12 changes: 12 additions & 0 deletions README.md
@@ -123,12 +123,18 @@ We test supporting the brand-new Graph Warp Module (GWM) [18]-attached models fo

The following datasets are currently supported:

### Chemical
- QM9 [7, 8]
- Tox21 [9]
- MoleculeNet [11]
- ZINC (only 250k dataset) [12, 13]
- User (own) dataset

### Network
- cora [21]
- citeseer [22]
- reddit [23]

## Research Projects

If you use Chainer Chemistry in your research, feel free to submit a
@@ -206,3 +212,9 @@ papers. Use the library at your own risk.
.

[20] Marc Brockschmidt, ``GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation'', arXiv:1906.12192 [cs.ML], 2019.

[21] Andrew Kachites McCallum, Kamal Nigam, Jason Rennie and Kristie Seymore, ``Automating the Construction of Internet Portals with Machine Learning'', *Information Retrieval*, 2000.

[22] C. Lee Giles, Kurt D. Bollacker and Steve Lawrence, ``CiteSeer: An Automatic Citation Indexing System'', *Proceedings of the Third ACM Conference on Digital Libraries*, 1998.

[23] William L. Hamilton, Zhitao Ying and Jure Leskovec, ``Inductive Representation Learning on Large Graphs'', *Advances in Neural Information Processing Systems 30 (NIPS 2017)*, 2017.
Empty file.
73 changes: 73 additions & 0 deletions chainer_chemistry/dataset/graph_dataset/base_graph_data.py
@@ -0,0 +1,73 @@
import numpy
import chainer


class BaseGraphData(object):
"""Base class of graph data """

def __init__(self, *args, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def to_device(self, device):
"""Send self to `device`

Args:
device (chainer.backend.Device): device

Returns:
self sent to `device`
"""
for k, v in self.__dict__.items():
if isinstance(v, numpy.ndarray):
setattr(self, k, device.send(v))
elif isinstance(v, chainer.utils.CooMatrix):
data = device.send(v.data.array)
row = device.send(v.row)
col = device.send(v.col)
device_coo_matrix = chainer.utils.CooMatrix(
data, row, col, v.shape, order=v.order)
setattr(self, k, device_coo_matrix)
return self


class PaddingGraphData(BaseGraphData):
"""Graph data class for padding pattern

Args:
x (numpy.ndarray): input node feature
adj (numpy.ndarray): adjacency matrix
super_node (numpy.ndarray, optional): feature of the super node
pos (numpy.ndarray, optional): position of each node
y (int or numpy.ndarray): graph or node label
"""

def __init__(self, x=None, adj=None, super_node=None, pos=None, y=None,
**kwargs):
self.x = x
self.adj = adj
self.super_node = super_node
self.pos = pos
self.y = y
self.n_nodes = None if x is None else x.shape[0]
super(PaddingGraphData, self).__init__(**kwargs)


class SparseGraphData(BaseGraphData):
"""Graph data class for sparse pattern

Args:
x (numpy.ndarray): input node feature
edge_index (numpy.ndarray): edge indices with shape (2, n_edges); row 0 holds source node ids, row 1 destination node ids
edge_attr (numpy.ndarray): attributes of edges
pos (numpy.ndarray, optional): position of each node
super_node (numpy.ndarray, optional): feature of the super node
y (int or numpy.ndarray): graph or node label
"""

def __init__(self, x=None, edge_index=None, edge_attr=None,
pos=None, super_node=None, y=None, **kwargs):
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.pos = pos
self.super_node = super_node
self.y = y
self.n_nodes = None if x is None else x.shape[0]
super(SparseGraphData, self).__init__(**kwargs)
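For orientation, a minimal usage sketch of these data classes (not part of the diff; the graph size, feature widths, and device choice below are invented for illustration):

import numpy
import chainer
from chainer_chemistry.dataset.graph_dataset.base_graph_data import \
    PaddingGraphData

# 3-node graph with 8 features per node and a graph-level label
x = numpy.random.rand(3, 8).astype(numpy.float32)
adj = numpy.eye(3, dtype=numpy.float32)
data = PaddingGraphData(x=x, adj=adj, y=numpy.int32(1))
assert data.n_nodes == 3  # derived from x.shape[0] in the constructor

device = chainer.get_device('@numpy')  # CPU; e.g. '@cupy:0' for the first GPU
data = data.to_device(device)          # sends every ndarray attribute to the device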
134 changes: 134 additions & 0 deletions chainer_chemistry/dataset/graph_dataset/base_graph_dataset.py
@@ -0,0 +1,134 @@
import chainer
import numpy
from chainer.backend import Device
from chainer_chemistry.dataset.graph_dataset.base_graph_data import \
BaseGraphData
from chainer_chemistry.dataset.graph_dataset.feature_converters \
import batch_with_padding, batch_without_padding, concat, shift_concat, \
concat_with_padding, shift_concat_with_padding


class BaseGraphDataset(object):
"""Base class of graph dataset (list of graph data)"""
_pattern = ''

def __init__(self, data_list, *args, **kwargs):
self.data_list = data_list
# Registries are per-instance: defining them as class attributes
# would share one mutable list across every dataset instance.
self._feature_entries = []
self._feature_batch_method = []

def register_feature(self, key, batch_method, skip_if_none=True):
"""Register feature with batch method

Args:
key (str): name of the feature
batch_method (function): batch method
skip_if_none (bool, optional): If True, skip registration when the
first data's `key` attribute is None. Defaults to True.
"""
if skip_if_none and getattr(self.data_list[0], key, None) is None:
return
self._feature_entries.append(key)
self._feature_batch_method.append(batch_method)

def update_feature(self, key, batch_method):
"""Update batch method of the feature
Args:
key (str): name of the feature
batch_method (function): batch method
"""

index = self._feature_entries.index(key)
self._feature_batch_method[index] = batch_method

def __len__(self):
return len(self.data_list)

def __getitem__(self, item):
return self.data_list[item]

def converter(self, batch, device=None):
"""Converter

Args:
batch (list[BaseGraphData]): list of graph data
device (int, optional): specifier of device. Defaults to None.

Returns:
BaseGraphData: batched data with each feature sent to `device`
"""
if not isinstance(device, Device):
device = chainer.get_device(device)
feats = [method(name, batch, device=device) for name, method in
zip(self._feature_entries, self._feature_batch_method)]
data = BaseGraphData(
**{key: value for key, value in zip(self._feature_entries, feats)})
return data


class PaddingGraphDataset(BaseGraphDataset):
"""Graph dataset class for padding pattern"""
_pattern = 'padding'

def __init__(self, data_list):
super(PaddingGraphDataset, self).__init__(data_list)
self.register_feature('x', batch_with_padding)
self.register_feature('adj', batch_with_padding)
self.register_feature('super_node', batch_with_padding)
self.register_feature('pos', batch_with_padding)
self.register_feature('y', batch_without_padding)
self.register_feature('n_nodes', batch_without_padding)


class SparseGraphDataset(BaseGraphDataset):
"""Graph dataset class for sparse pattern"""
_pattern = 'sparse'

def __init__(self, data_list):
super(SparseGraphDataset, self).__init__(data_list)
self.register_feature('x', concat)
self.register_feature('edge_index', shift_concat)
self.register_feature('edge_attr', concat)
self.register_feature('super_node', concat)
self.register_feature('pos', concat)
self.register_feature('y', batch_without_padding)
self.register_feature('n_nodes', batch_without_padding)

def converter(self, batch, device=None):
"""Converter

Adds `batch` to the returned data, which represents the index of the
graph each node belongs to.

Args:
batch (list[BaseGraphData]): list of graph data
device (int, optional): specifier of device. Defaults to None.

Returns:
BaseGraphData: batched data with each feature sent to `device`
"""
data = super(SparseGraphDataset, self).converter(batch, device=device)
if not isinstance(device, Device):
device = chainer.get_device(device)
data.batch = numpy.concatenate([
numpy.full(d.x.shape[0], i, dtype=int)
for i, d in enumerate(batch)
])
data.batch = device.send(data.batch)
return data

# For experiments only; use `converter` for normal use.
def converter_with_padding(self, batch, device=None):
self.update_feature('x', concat_with_padding)
self.update_feature('edge_index', shift_concat_with_padding)
data = super(SparseGraphDataset, self).converter(batch, device=device)
if not isinstance(device, Device):
device = chainer.get_device(device)
max_n_nodes = max(d.x.shape[0] for d in batch)
data.batch = numpy.concatenate([
numpy.full(max_n_nodes, i, dtype=int)
for i in range(len(batch))
])
data.batch = device.send(data.batch)
return data
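A hedged sketch of sparse-pattern batching with this class (not from the PR; the `triangle` helper and all sizes are made up for the example):

import numpy
import chainer
from chainer_chemistry.dataset.graph_dataset.base_graph_data import \
    SparseGraphData
from chainer_chemistry.dataset.graph_dataset.base_graph_dataset import \
    SparseGraphDataset

def triangle(label):
    # 3-node cycle; edge_index is (2, n_edges): row 0 sources, row 1 destinations
    x = numpy.random.rand(3, 4).astype(numpy.float32)
    edge_index = numpy.array([[0, 1, 2], [1, 2, 0]], dtype=numpy.int64)
    return SparseGraphData(x=x, edge_index=edge_index, y=numpy.int32(label))

dataset = SparseGraphDataset([triangle(0), triangle(1)])
batch = dataset.converter([dataset[0], dataset[1]],
                          device=chainer.get_device('@numpy'))
print(batch.x.shape)     # (6, 4): both graphs' nodes concatenated
print(batch.edge_index)  # the second graph's node ids are shifted by 3
print(batch.batch)       # [0 0 0 1 1 1]: which graph each node belongs to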
115 changes: 115 additions & 0 deletions chainer_chemistry/dataset/graph_dataset/feature_converters.py
@@ -0,0 +1,115 @@
import numpy
from chainer.dataset.convert import _concat_arrays


def batch_with_padding(name, batch, device=None, pad=0):
"""Batch with padding (increase ndim by 1)

Args:
name (str): property name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.
pad (int, optional): padding value. Defaults to 0.

Returns:
numpy.ndarray or cupy.ndarray: batched feature array sent to `device`
"""
feat = _concat_arrays(
[getattr(example, name) for example in batch], pad)
return device.send(feat)


def batch_without_padding(name, batch, device=None):
"""Batch without padding (increase ndim by 1)

Args:
name (str): property name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.

Returns:
numpy.ndarray or cupy.ndarray: batched feature array sent to `device`
"""
feat = _concat_arrays(
[getattr(example, name) for example in batch], None)
return device.send(feat)


def concat_with_padding(name, batch, device=None, pad=0):
"""Concat without padding (ndim does not increase)

Args:
name (str): property name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.
pad (int, optional): padding value. Defaults to 0.

Returns:
numpy.ndarray or cupy.ndarray: padded, flattened feature array sent to `device`
"""
feat = batch_with_padding(name, batch, device=device, pad=pad)
a, b = feat.shape
return feat.reshape(a * b)


def concat(name, batch, device=None, axis=0):
"""Concat with padding (ndim does not increase)

Args:
name (str): property name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.
axis (int, optional): concatenation axis. Defaults to 0.

Returns:
numpy.ndarray or cupy.ndarray: concatenated feature array sent to `device`
"""
feat = numpy.concatenate([getattr(data, name) for data in batch],
axis=axis)
return device.send(feat)


def shift_concat(name, batch, device=None, shift_attr='x', shift_axis=1):
"""Concat with index shift (ndim does not increase)

Concatenate graphs into a big one.
Used for sparse pattern batching.

Args:
name (str): property name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.
shift_attr (str, optional): attribute whose first dimension gives each
graph's node count, used to compute the index shift. Defaults to 'x'.
shift_axis (int, optional): concatenation axis. Defaults to 1.

Returns:
numpy.ndarray or cupy.ndarray: concatenated feature array sent to `device`
"""
shift_index_array = numpy.cumsum(
numpy.array([0] + [getattr(data, shift_attr).shape[0]
for data in batch]))
feat = numpy.concatenate([
getattr(data, name) + shift_index_array[i]
for i, data in enumerate(batch)], axis=shift_axis)
return device.send(feat)


def shift_concat_with_padding(name, batch, device=None, shift_attr='x',
shift_axis=1):
"""Concat with index shift and padding (ndim does not increase)

Concatenate graphs into a big one.
Used for sparse pattern batching.

Args:
name (str): property name of graph data
batch (list[BaseGraphData]): list of base graph data
device (chainer.backend.Device, optional): device. Defaults to None.
shift_attr (str, optional): attribute whose first dimension gives each
graph's node count. Unused here; shifts are multiples of the maximum
node count. Defaults to 'x'.
shift_axis (int, optional): concatenation axis. Defaults to 1.

Returns:
numpy.ndarray or cupy.ndarray: concatenated feature array sent to `device`
"""
max_n_nodes = max(data.x.shape[0] for data in batch)
shift_index_array = numpy.arange(0, len(batch) * max_n_nodes, max_n_nodes)
feat = numpy.concatenate([
getattr(data, name) + shift_index_array[i]
for i, data in enumerate(batch)], axis=shift_axis)
return device.send(feat)
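To make the index shift concrete, a small sketch (the `_Graph` class is a hypothetical stand-in for `BaseGraphData`, used only for this example):

import numpy
import chainer
from chainer_chemistry.dataset.graph_dataset.feature_converters import \
    shift_concat

class _Graph(object):
    # minimal stand-in exposing only the attributes shift_concat reads
    def __init__(self, n_nodes, edge_index):
        self.x = numpy.zeros((n_nodes, 1), dtype=numpy.float32)
        self.edge_index = edge_index

g0 = _Graph(2, numpy.array([[0], [1]]))        # one edge: 0 -> 1
g1 = _Graph(3, numpy.array([[0, 1], [1, 2]]))  # two edges: 0 -> 1, 1 -> 2

merged = shift_concat('edge_index', [g0, g1],
                      device=chainer.get_device('@numpy'))
print(merged)
# [[0 2 3]
#  [1 3 4]]  -- g1's node ids are offset by g0's node count (2)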