- Introduction
- Aside (using MapFeatures)
- Node feature transformation
- Edge feature transformation
- Passing extra parameters
- Problem
- Problem demo
- Solution
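Introduction
Because of how graph data is structured, ordinary TensorFlow layers cannot process it directly. To transform the node features or edge features of a graph, you need the MapFeatures layer from the TF-GNN framework.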
Aside (using MapFeatures)
Node feature transformation
    from tensorflow.keras.layers import BatchNormalization
    from tensorflow_gnn.keras.layers import MapFeatures

    # map node features
    def node_sets_fn(node_set, *, node_set_name):
        features = node_set.features
        return BatchNormalization()(features["hidden_state"])

    graph = MapFeatures(node_sets_fn=node_sets_fn)(graph)
Edge feature transformation
    import tensorflow as tf
    from tensorflow_gnn.keras.layers import MapFeatures

    # Hashes edge features called "id", leaves others unchanged:
    def edge_sets_fn(edge_set, *, edge_set_name):
        features = edge_set.get_features_dict()
        ids = features.pop("id")
        num_bins = 100_000 if edge_set_name == "views" else 20_000
        hashed_ids = tf.keras.layers.Hashing(num_bins=num_bins)(ids)
        features["hashed_id"] = hashed_ids
        return features

    graph = MapFeatures(edge_sets_fn=edge_sets_fn)(graph)
Passing extra parameters
    from functools import partial
    from tensorflow.keras.layers import Dense
    from tensorflow_gnn.keras.layers import MapFeatures

    # map node features; dim is an extra keyword argument bound with partial
    def node_sets_fn(node_set, *, node_set_name, dim):
        features = node_set.features
        return Dense(dim)(features["hidden_state"])

    graph = MapFeatures(node_sets_fn=partial(node_sets_fn, dim=64))(graph)
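The binding of the extra parameter can be tried without TF-GNN installed. The sketch below is only an illustration: `call_callback` is a hypothetical stand-in for the way MapFeatures invokes the callback (once per node set, passing `node_set_name` by keyword), and the node sets are plain dicts rather than real graph pieces.

```python
from functools import partial

def node_sets_fn(node_set, *, node_set_name, dim):
    # Stand-in for a real transformation: just report what was received.
    return {
        "node_set_name": node_set_name,
        "dim": dim,
        "num_features": len(node_set),
    }

def call_callback(fn, node_sets):
    # Hypothetical stand-in for the caller (not the TF-GNN implementation):
    # one call per node set, with node_set_name passed as a keyword argument.
    return {
        name: fn(node_set, node_set_name=name)
        for name, node_set in node_sets.items()
    }

# partial pre-binds dim, so the caller only has to supply node_set_name.
bound_fn = partial(node_sets_fn, dim=64)

result = call_callback(
    bound_fn,
    {
        "paper": {"hidden_state": [1.0, 2.0]},
        "author": {"hidden_state": [3.0]},
    },
)
```

Because `dim` is keyword-only, `partial` has to bind it by keyword, which is exactly what `partial(node_sets_fn, dim=64)` does.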
Problem
When MapFeatures is applied repeatedly in a loop, saving the model fails with the error: ValueError: Unable to create dataset (name already exists)
Problem demo
    from functools import partial
    from tensorflow.keras.layers import Dense
    from tensorflow_gnn.keras.layers import MapFeatures

    # map node features
    def node_sets_fn(node_set, *, node_set_name, dim):
        features = node_set.features
        return Dense(dim)(features["hidden_state"])

    # layer_num (number of stacked transformations) and graph are defined elsewhere
    for ln in range(layer_num):
        graph = MapFeatures(node_sets_fn=partial(node_sets_fn, dim=64))(graph)
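A possible way to check for the collision before saving is to look for repeated names. This is a minimal sketch, assuming the error comes from two saved objects sharing a name, as the message suggests. `find_duplicate_names` is a hypothetical helper and the example names are made up for illustration.

```python
from collections import Counter

def find_duplicate_names(names):
    """Returns the names that occur more than once, sorted alphabetically."""
    counts = Counter(names)
    return sorted(name for name, count in counts.items() if count > 1)

# Hypothetical weight names, for illustration only.
example_names = [
    "dense/kernel:0",
    "dense/bias:0",
    "dense/kernel:0",
    "dense/bias:0",
]
duplicates = find_duplicate_names(example_names)
print(duplicates)
```

With a real model, the input could be something like `[w.name for w in model.weights]`, where `model` is the Keras model about to be saved.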
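Solution
The cause turned out to be layer naming: when a layer such as Dense is created inside the MapFeatures callback, each transformation must give its layer a distinct name.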
    from functools import partial
    from tensorflow.keras.layers import Dense
    from tensorflow_gnn.keras.layers import MapFeatures

    # map node features
    def node_sets_fn(node_set, *, node_set_name, dim, name):
        features = node_set.features
        return Dense(dim, name=f"Dense_{name}")(features["hidden_state"])

    for ln in range(layer_num):
        graph = MapFeatures(node_sets_fn=partial(node_sets_fn, dim=64, name=ln))(graph)
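If the graph has more than one node set, the callback runs once per node set, so a name built only from the loop index may repeat within a single MapFeatures call. A possible refinement, under that assumption, is to include `node_set_name` in the layer name as well. `unique_layer_name` is a hypothetical helper, and the node set names "paper" and "author" are examples only.

```python
def unique_layer_name(prefix, node_set_name, layer_index):
    """Builds a name that is distinct per node set and per loop iteration."""
    return f"{prefix}_{node_set_name}_{layer_index}"

# Simulate three loop iterations over a graph with two node sets.
names = [
    unique_layer_name("Dense", node_set_name, ln)
    for ln in range(3)
    for node_set_name in ("paper", "author")
]

# Every (node set, iteration) pair gets its own name.
all_unique = len(names) == len(set(names))
```

Inside the callback this would replace the f-string, for example `Dense(dim, name=unique_layer_name("Dense", node_set_name, name))`.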