Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] add tdm serving tree & remove recall_num in gen tree #2

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions docs/source/quick_start/local_tutorial_tdm.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,18 @@ python -m tzrec.tools.tdm.init_tree \
--cate_id_field cate_id \
--attr_fields cate_id,campaign_id,customer,brand,price \
--node_edge_output_file data/init_tree
--tree_output_dir data/init_tree
```

- --item_input_path: 建树用的item特征文件
- --item_id_field: 代表item的id的列名
- --cate_id_field: 代表item的类别的列名
- --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --tree_output_file: (可选)初始树的保存路径, 不输入不会保存
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种
- --recall_num: (可选,默认为200)召回数量, 会根据召回数量自动跳过前几层树, 增加召回的效率
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表`和`本地txt GL`两种
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GL的全称是什么?

- ODPS GL表:设置形如`odps://{project}/tables/{tb_prefix}`,将会产出用于TDM训练负采样的GL Node表`odps://{project}/tables/{tb_prefix}_node_table`、GL Edge表`odps://{project}/tables/{tb_prefix}_edge_table`、用于离线检索的GL Edge表`odps://{project}/tables/{tb_prefix}_predict_edge_table`
- 本地txt GL表:设置的为目录, 将在目录下产出用于TDM训练负采样的GL Node表`node_table.txt`,GL Edge表`edge_table.txt`、用于离线检索的GL Edge表`predict_edge_table.txt`
- --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务
- --n_cluster: (可选,默认为2)树的分叉数

#### 训练
Expand Down Expand Up @@ -80,11 +82,13 @@ torchrun --master_addr=localhost --master_port=32555 \
-m tzrec.export \
--pipeline_config_path experiments/tdm_taobao_local/pipeline.config \
--export_dir experiments/tdm_taobao_local/export
--asset_files data/init_tree/serving_tree
```

- --pipeline_config_path: 导出用的配置文件
- --checkpoint_path: 指定要导出的checkpoint, 默认评估model_dir下面最新的checkpoint
- --export_dir: 导出到的模型目录
- --asset_files: 需额拷贝到模型目录的文件。tdm需拷贝serving_tree树文件用于线上服务

#### 导出item embedding

Expand Down Expand Up @@ -114,6 +118,7 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \
--embedding_field item_emb \
--attr_fields cate_id,campaign_id,customer,brand,price \
--node_edge_output_file data/learnt_tree \
--tree_output_dir data/learnt_tree \
--parallel 16
```

Expand All @@ -122,9 +127,8 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \
- --embedding_field: 代表item embedding的列名
- --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致
- --tree_output_file: (可选)树的保存路径, 不输入不会保存
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种
- --recall_num: (可选,默认为200)召回数量, 会根据召回数量自动跳过前几层树, 增加召回的效率
- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表`和`本地txt GL`两种,同初始树
- --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务
- --n_cluster: (可选,默认为2)树的分叉数
- --parllel: (可选,默认为16)聚类时CPU并行数

Expand Down Expand Up @@ -153,11 +157,13 @@ torchrun --master_addr=localhost --master_port=32555 \
-m tzrec.export \
--pipeline_config_path experiments/tdm_taobao_local_learnt/pipeline.config \
--export_dir experiments/tdm_taobao_local_learnt/export
--asset_files data/learnt_tree/serving_tree
```

- --pipeline_config_path: 导出用的配置文件
- --checkpoint_path: 指定要导出的checkpoint, 默认评估model_dir下面最新的checkpoint
- --export_dir: 导出到的模型目录
- --asset_files: 需额拷贝到模型目录的文件。tdm需拷贝serving_tree树文件用于线上服务

#### Recall评估

Expand All @@ -181,7 +187,7 @@ torchrun --master_addr=localhost --master_port=32555 \
- --predict_input_path: 预测输入数据的路径
- --predict_output_path: 预测输出数据的路径
- --gt_item_id_field: 文件中代表真实点击item_id的列名
- --recall_num:(可选, 默认为200) 召回的数量, 应与建树时输入保持一致
- --recall_num:(可选, 默认为200) 召回的数量
- --n_cluster:(可选, 默认为2) 数的分叉数量, 应与建树时输入保持一致
- --reserved_columns: 预测结果中要保留的输入列

Expand Down
7 changes: 7 additions & 0 deletions tzrec/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,17 @@
default=None,
help="directory where model should be exported to.",
)
parser.add_argument(
"--asset_files",
type=str,
default=None,
help="more files will be copy to export_dir.",
)
args, extra_args = parser.parse_known_args()

export(
args.pipeline_config_path,
export_dir=args.export_dir,
checkpoint_path=args.checkpoint_path,
asset_files=args.asset_files,
)
21 changes: 18 additions & 3 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import itertools
import json
import os
import shutil
from collections import OrderedDict
from queue import Queue
from threading import Thread
Expand Down Expand Up @@ -747,7 +748,10 @@ def _script_model(


def export(
pipeline_config_path: str, export_dir: str, checkpoint_path: Optional[str] = None
pipeline_config_path: str,
export_dir: str,
checkpoint_path: Optional[str] = None,
asset_files: Optional[str] = None,
) -> None:
"""Export a EasyRec model.

Expand All @@ -756,6 +760,7 @@ def export(
export_dir (str): base directory where the model should be exported.
checkpoint_path (str, optional): if specified, will use this model instead of
model specified by model_dir in pipeline_config_path.
asset_files (str, optional): more files will be copy to export_dir.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more files will be copied to export_dir.
copy ->copied

"""
pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
ori_pipeline_config = copy.copy(pipeline_config)
Expand All @@ -766,6 +771,10 @@ def export(
if os.path.exists(export_dir):
raise RuntimeError(f"directory {export_dir} already exist.")

assets = []
if asset_files:
assets = asset_files.split(",")

data_config = pipeline_config.data_config
# Build feature
features = _create_features(list(pipeline_config.feature_configs), data_config)
Expand Down Expand Up @@ -832,13 +841,16 @@ def export(
for name, module in cpu_model.named_children():
if isinstance(module, MatchTower):
tower = ScriptWrapper(TowerWrapper(module, name))
tower_export_dir = os.path.join(export_dir, name.replace("_tower", ""))
_script_model(
ori_pipeline_config,
tower,
cpu_state_dict,
dataloader,
os.path.join(export_dir, name.replace("_tower", "")),
tower_export_dir,
)
for asset in assets:
shutil.copy(asset, tower_export_dir)
elif isinstance(cpu_model, TDM):
for name, module in cpu_model.named_children():
if isinstance(module, EmbeddingGroup):
Expand All @@ -857,7 +869,8 @@ def export(
dataloader,
export_dir,
)

for asset in assets:
shutil.copy(asset, export_dir)
else:
_script_model(
ori_pipeline_config,
Expand All @@ -866,6 +879,8 @@ def export(
dataloader,
export_dir,
)
for asset in assets:
shutil.copy(asset, export_dir)


def predict(
Expand Down
6 changes: 5 additions & 1 deletion tzrec/tests/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import argparse
import os
import sys
import unittest


Expand Down Expand Up @@ -48,4 +49,7 @@ def _gather_test_cases(args):
runner = unittest.TextTestRunner()
test_suite = _gather_test_cases(args)
if not args.list_tests:
runner.run(test_suite)
result = runner.run(test_suite)
failed, errored = len(result.failures), len(result.errors)
if failed > 0 or errored > 0:
sys.exit(1)
9 changes: 7 additions & 2 deletions tzrec/tests/train_eval_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ def test_tdm_train_eval_export(self):
)
if self.success:
self.success = utils.test_export(
os.path.join(self.test_dir, "pipeline.config"), self.test_dir
os.path.join(self.test_dir, "pipeline.config"),
self.test_dir,
asset_files=os.path.join(self.test_dir, "init_tree/serving_tree"),
)
if self.success:
self.success = utils.test_predict(
Expand All @@ -557,7 +559,7 @@ def test_tdm_train_eval_export(self):
embedding_field="item_emb",
)
if self.success:
with open(os.path.join(self.test_dir, "node_table.txt")) as f:
with open(os.path.join(self.test_dir, "init_tree/node_table.txt")) as f:
for line_number, line in enumerate(f):
if line_number == 1:
root_id = int(line.split("\t")[0])
Expand Down Expand Up @@ -586,6 +588,9 @@ def test_tdm_train_eval_export(self):
self.assertTrue(
os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt"))
)
self.assertTrue(
os.path.exists(os.path.join(self.test_dir, "export/serving_tree"))
)
self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result")))


Expand Down
24 changes: 16 additions & 8 deletions tzrec/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,16 +791,20 @@ def load_config_for_test(
f"--cate_id_field {cate_id} "
f"--attr_fields {','.join(attr_fields)} "
f"--raw_attr_fields {','.join(raw_attr_fields)} "
f"--node_edge_output_file {test_dir} "
f"--recall_num 1"
f"--node_edge_output_file {test_dir}/init_tree "
f"--tree_output_dir {test_dir}/init_tree "
)
p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_init_tree.txt"))
p.wait(600)

sampler_config.item_input_path = os.path.join(test_dir, "node_table.txt")
sampler_config.edge_input_path = os.path.join(test_dir, "edge_table.txt")
sampler_config.item_input_path = os.path.join(
test_dir, "init_tree/node_table.txt"
)
sampler_config.edge_input_path = os.path.join(
test_dir, "init_tree/edge_table.txt"
)
sampler_config.predict_edge_input_path = os.path.join(
test_dir, "predict_edge_table.txt"
test_dir, "init_tree/predict_edge_table.txt"
)

else:
Expand Down Expand Up @@ -874,7 +878,9 @@ def test_eval(pipeline_config_path: str, test_dir: str) -> bool:
return True


def test_export(pipeline_config_path: str, test_dir: str) -> bool:
def test_export(
pipeline_config_path: str, test_dir: str, asset_files: str = ""
) -> bool:
"""Run export integration test."""
port = misc_util.get_free_port()
log_dir = os.path.join(test_dir, "log_export")
Expand All @@ -884,8 +890,10 @@ def test_export(pipeline_config_path: str, test_dir: str) -> bool:
f"--nproc-per-node=2 --node_rank=0 --log_dir {log_dir} "
"-r 3 -t 3 tzrec/export.py "
f"--pipeline_config_path {pipeline_config_path} "
f"--export_dir {test_dir}/export"
f"--export_dir {test_dir}/export "
)
if asset_files:
cmd_str += f"--asset_files {asset_files}"

p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_export.txt"))
p.wait(600)
Expand Down Expand Up @@ -1205,8 +1213,8 @@ def test_tdm_cluster_train_eval(
f"--attr_fields {','.join(attr_fields)} "
f"--raw_attr_fields {','.join(raw_attr_fields)} "
f"--node_edge_output_file {os.path.join(test_dir, 'learnt_tree')} "
f"--tree_output_dir {os.path.join(test_dir, 'learnt_tree')} "
f"--parallel 1 "
f"--recall_num 1 "
)
p = misc_util.run_cmd(
cluster_cmd_str, os.path.join(test_dir, "log_tdm_cluster.txt")
Expand Down
22 changes: 6 additions & 16 deletions tzrec/tools/tdm/cluster_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# limitations under the License.

import argparse
import math

from tzrec.tools.tdm.gen_tree.tree_cluster import TreeCluster
from tzrec.tools.tdm.gen_tree.tree_search_util import TreeSearch
Expand Down Expand Up @@ -49,7 +48,7 @@
help="The column names representing the raw features of item in the file.",
)
parser.add_argument(
"--tree_output_file",
"--tree_output_dir",
type=str,
default=None,
help="The tree output file.",
Expand All @@ -66,12 +65,6 @@
default=16,
help="The number of CPU cores for parallel processing.",
)
parser.add_argument(
"--recall_num",
type=int,
default=200,
help="Recall number per item when retrieval.",
)
parser.add_argument(
"--n_cluster",
type=int,
Expand All @@ -85,23 +78,20 @@
item_id_field=args.item_id_field,
attr_fields=args.attr_fields,
raw_attr_fields=args.raw_attr_fields,
output_file=args.tree_output_file,
output_dir=args.tree_output_dir,
embedding_field=args.embedding_field,
parallel=args.parallel,
n_cluster=args.n_cluster,
)
if args.tree_output_file:
save_tree = True
else:
save_tree = False
root = cluster.train(save_tree)
root = cluster.train()
logger.info("Tree cluster done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
root=root,
child_num=args.n_cluster,
)
tree_search.save()
first_recall_layer = int(math.ceil(math.log(2 * args.recall_num, args.n_cluster)))
tree_search.save_predict_edge(first_recall_layer)
tree_search.save_predict_edge()
if args.tree_output_dir:
tree_search.save_serving_tree(args.tree_output_dir)
logger.info("Save nodes and edges table done.")
18 changes: 9 additions & 9 deletions tzrec/tools/tdm/gen_tree/tree_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ class TreeCluster:
"""Cluster based on emb vec.
Args:
item_input_path(str): The file path where the item information is stored.
item_id_field(str): The column name representing item_id in the file.
attr_fields(List[str]): The column names representing the features in the file.
output_file(str): The output file.
parallel(int): The number of CPU cores for parallel processing.
n_cluster(int): The branching factor of the nodes in the tree.
item_input_path (str): The file path where the item information is stored.
item_id_field (str): The column name representing item_id in the file.
attr_fields (List[str]): The column names representing the features in the file.
output_dir (str): The output file.
parallel (int): The number of CPU cores for parallel processing.
n_cluster (int): The branching factor of the nodes in the tree.
"""

def __init__(
Expand All @@ -45,7 +45,7 @@ def __init__(
item_id_field: str,
attr_fields: Optional[str] = None,
raw_attr_fields: Optional[str] = None,
output_file: Optional[str] = None,
output_dir: Optional[str] = None,
embedding_field: str = "item_emb",
parallel: int = 16,
n_cluster: int = 2,
Expand All @@ -60,7 +60,7 @@ def __init__(
self.queue = None
self.timeout = 5
self.codes = None
self.output_file = output_file
self.output_dir = output_dir
self.n_clusters = n_cluster

self.item_id_field = item_id_field
Expand Down Expand Up @@ -140,7 +140,7 @@ def train(self, save_tree: bool = False) -> TDMTreeClass:
p.join()

assert queue.empty()
builder = tree_builder.TreeBuilder(self.output_file, self.n_clusters)
builder = tree_builder.TreeBuilder(self.output_dir, self.n_clusters)
root = builder.build(
self.ids, self.codes, self.attrs, self.raw_attrs, self.data, save_tree
)
Expand Down
2 changes: 1 addition & 1 deletion tzrec/tools/tdm/gen_tree/tree_cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_cluster(self) -> None:
embedding_field="item_emb",
attr_fields="cate_id,str_a",
raw_attr_fields="raw_1",
output_file=None,
output_dir=None,
parallel=1,
n_cluster=2,
)
Expand Down
Loading
Loading