Skip to content

Commit

Permalink
make serving tree use tree_output_dir as directory, node_edge_output_…
Browse files Browse the repository at this point in the history
…file may be odps table
  • Loading branch information
tiankongdeguiji committed Oct 4, 2024
1 parent 3146bf0 commit ec8ef1d
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 36 deletions.
12 changes: 8 additions & 4 deletions docs/source/quick_start/local_tutorial_tdm.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,18 @@ python -m tzrec.tools.tdm.init_tree \
--cate_id_field cate_id \
--attr_fields cate_id,campaign_id,customer,brand,price \
--node_edge_output_file data/init_tree
--tree_output_dir data/init_tree
```

- --item_input_path: 建树用的item特征文件
- --item_id_field: 代表item的id的列名
- --cate_id_field: 代表item的类别的列名
- --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --tree_output_file: (可选)初始树的保存路径, 不输入不会保存
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表``本地txt GL`两种
- ODPS GL表:设置形如`odps://{project}/tables/{tb_prefix}`,将会产出用于TDM训练负采样的GL Node表`odps://{project}/tables/{tb_prefix}_node_table`、GL Edge表`odps://{project}/tables/{tb_prefix}_edge_table`、用于离线检索的GL Edge表`odps://{project}/tables/{tb_prefix}_predict_edge_table`
- 本地txt GL表:设置的为目录, 将在目录下产出用于TDM训练负采样的GL Node表`node_table.txt`,GL Edge表`edge_table.txt`、用于离线检索的GL Edge表`predict_edge_table.txt`
- --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务
- --n_cluster: (可选,默认为2)树的分叉数

#### 训练
Expand Down Expand Up @@ -115,6 +118,7 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \
--embedding_field item_emb \
--attr_fields cate_id,campaign_id,customer,brand,price \
--node_edge_output_file data/learnt_tree \
--tree_output_dir data/learnt_tree \
--parallel 16
```

Expand All @@ -123,8 +127,8 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \
- --embedding_field: 代表item embedding的列名
- --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --tree_output_file: (可选)树的保存路径, 不输入不会保存
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表``本地txt GL`两种,同初始树
- --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务
- --n_cluster: (可选,默认为2)树的分叉数
- --parllel: (可选,默认为16)聚类时CPU并行数

Expand Down
2 changes: 2 additions & 0 deletions tzrec/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ def load_config_for_test(
f"--attr_fields {','.join(attr_fields)} "
f"--raw_attr_fields {','.join(raw_attr_fields)} "
f"--node_edge_output_file {test_dir}/init_tree "
f"--tree_output_dir {test_dir}/init_tree "
)
p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_init_tree.txt"))
p.wait(600)
Expand Down Expand Up @@ -1212,6 +1213,7 @@ def test_tdm_cluster_train_eval(
f"--attr_fields {','.join(attr_fields)} "
f"--raw_attr_fields {','.join(raw_attr_fields)} "
f"--node_edge_output_file {os.path.join(test_dir, 'learnt_tree')} "
f"--tree_output_dir {os.path.join(test_dir, 'learnt_tree')} "
f"--parallel 1 "
)
p = misc_util.run_cmd(
Expand Down
13 changes: 5 additions & 8 deletions tzrec/tools/tdm/cluster_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
help="The column names representing the raw features of item in the file.",
)
parser.add_argument(
"--tree_output_file",
"--tree_output_dir",
type=str,
default=None,
help="The tree output file.",
Expand Down Expand Up @@ -78,16 +78,12 @@
item_id_field=args.item_id_field,
attr_fields=args.attr_fields,
raw_attr_fields=args.raw_attr_fields,
output_file=args.tree_output_file,
output_dir=args.tree_output_dir,
embedding_field=args.embedding_field,
parallel=args.parallel,
n_cluster=args.n_cluster,
)
if args.tree_output_file:
save_tree = True
else:
save_tree = False
root = cluster.train(save_tree)
root = cluster.train()
logger.info("Tree cluster done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
Expand All @@ -96,5 +92,6 @@
)
tree_search.save()
tree_search.save_predict_edge()
tree_search.save_serving_tree()
if args.tree_output_dir:
tree_search.save_serving_tree(args.tree_output_dir)
logger.info("Save nodes and edges table done.")
18 changes: 9 additions & 9 deletions tzrec/tools/tdm/gen_tree/tree_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ class TreeCluster:
"""Cluster based on emb vec.
Args:
item_input_path(str): The file path where the item information is stored.
item_id_field(str): The column name representing item_id in the file.
attr_fields(List[str]): The column names representing the features in the file.
output_file(str): The output file.
parallel(int): The number of CPU cores for parallel processing.
n_cluster(int): The branching factor of the nodes in the tree.
item_input_path (str): The file path where the item information is stored.
item_id_field (str): The column name representing item_id in the file.
attr_fields (List[str]): The column names representing the features in the file.
output_dir (str): The output file.
parallel (int): The number of CPU cores for parallel processing.
n_cluster (int): The branching factor of the nodes in the tree.
"""

def __init__(
Expand All @@ -45,7 +45,7 @@ def __init__(
item_id_field: str,
attr_fields: Optional[str] = None,
raw_attr_fields: Optional[str] = None,
output_file: Optional[str] = None,
output_dir: Optional[str] = None,
embedding_field: str = "item_emb",
parallel: int = 16,
n_cluster: int = 2,
Expand All @@ -60,7 +60,7 @@ def __init__(
self.queue = None
self.timeout = 5
self.codes = None
self.output_file = output_file
self.output_dir = output_dir
self.n_clusters = n_cluster

self.item_id_field = item_id_field
Expand Down Expand Up @@ -140,7 +140,7 @@ def train(self, save_tree: bool = False) -> TDMTreeClass:
p.join()

assert queue.empty()
builder = tree_builder.TreeBuilder(self.output_file, self.n_clusters)
builder = tree_builder.TreeBuilder(self.output_dir, self.n_clusters)
root = builder.build(
self.ids, self.codes, self.attrs, self.raw_attrs, self.data, save_tree
)
Expand Down
6 changes: 3 additions & 3 deletions tzrec/tools/tdm/gen_tree/tree_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
cate_id_field: str,
attr_fields: Optional[str] = None,
raw_attr_fields: Optional[str] = None,
tree_output_file: Optional[str] = None,
tree_output_dir: Optional[str] = None,
n_cluster: int = 2,
) -> None:
self.item_input_path = item_input_path
Expand All @@ -49,7 +49,7 @@ def __init__(
self.attr_fields = [x.strip() for x in attr_fields.split(",")]
if raw_attr_fields:
self.raw_attr_fields = [x.strip() for x in raw_attr_fields.split(",")]
self.tree_output_file = tree_output_file
self.tree_output_dir = tree_output_dir
self.n_cluster = n_cluster

def generate(self, save_tree: bool = False) -> TDMTreeClass:
Expand Down Expand Up @@ -141,6 +141,6 @@ def gen_code(start: int, end: int, code: int, items: List[Item]) -> None:
)
data = np.array([[] for i in range(len(ids))])

builder = TreeBuilder(self.tree_output_file, self.n_cluster)
builder = TreeBuilder(self.tree_output_dir, self.n_cluster)
root = builder.build(ids, codes, attrs, raw_attrs, data, save_tree)
return root
6 changes: 4 additions & 2 deletions tzrec/tools/tdm/gen_tree/tree_search_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,11 @@ def save_predict_edge(self) -> None:
for child in node.children:
f.write(f"{node.item_id}\t{child.item_id}\t{1.0}\n")

def save_serving_tree(self) -> None:
def save_serving_tree(self, tree_output_dir: str) -> None:
"""Save tree info for serving."""
with open(os.path.join(self.output_file, "serving_tree"), "w") as f:
if not os.path.exists(tree_output_dir):
os.makedirs(tree_output_dir)
with open(os.path.join(tree_output_dir, "serving_tree"), "w") as f:
f.write(f"{self.max_level + 1} {self.child_num}\n")
for _, nodes in enumerate(self.level_code):
for node in nodes:
Expand Down
2 changes: 1 addition & 1 deletion tzrec/tools/tdm/gen_tree/tree_search_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_cluster(self) -> None:
search = TreeSearch(output_file=self.test_dir, root=root, child_num=2)
search.save()
search.save_predict_edge()
search.save_serving_tree()
search.save_serving_tree(self.test_dir)

node_table = []
edge_table = []
Expand Down
15 changes: 6 additions & 9 deletions tzrec/tools/tdm/init_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@
help="The column names representing the raw features of item in the file.",
)
parser.add_argument(
"--tree_output_file",
"--tree_output_dir",
type=str,
default=None,
help="The tree output file.",
help="The tree output directory.",
)
parser.add_argument(
"--node_edge_output_file",
Expand All @@ -73,14 +73,10 @@
cate_id_field=args.cate_id_field,
attr_fields=args.attr_fields,
raw_attr_fields=args.raw_attr_fields,
tree_output_file=args.tree_output_file,
tree_output_dir=args.tree_output_dir,
n_cluster=args.n_cluster,
)
if args.tree_output_file:
save_tree = True
else:
save_tree = False
root = generator.generate(save_tree)
root = generator.generate()
logger.info("Tree init done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
Expand All @@ -89,5 +85,6 @@
)
tree_search.save()
tree_search.save_predict_edge()
tree_search.save_serving_tree()
if args.tree_output_dir:
tree_search.save_serving_tree(args.tree_output_dir)
logger.info("Save nodes and edges table done.")

0 comments on commit ec8ef1d

Please sign in to comment.