feat: 模型演化添加点击节点展开数据集和项目

This commit is contained in:
cp3hnu 2024-07-01 11:45:39 +08:00
parent e88cc9e197
commit 2cc70b1f00
4 changed files with 180 additions and 66 deletions

View File

@ -1,9 +1,15 @@
/*
* @Author:
* @Date: 2024-06-07 11:24:10
* @Description:
*/
import { useEffectWhen } from '@/hooks';
import { ResourceVersionData } from '@/pages/Dataset/config';
import { getModelAtlasReq } from '@/services/dataset/index.js';
import themes from '@/styles/theme.less';
import { to } from '@/utils/promise';
import G6, { G6GraphEvent, Graph } from '@antv/g6';
import G6, { G6GraphEvent, Graph, INode } from '@antv/g6';
// @ts-ignore
import { Flex, Select } from 'antd';
import { useEffect, useRef, useState } from 'react';
@ -11,7 +17,15 @@ import GraphLegend from '../GraphLegend';
import NodeTooltips from '../NodeTooltips';
import styles from './index.less';
import type { ModelDepsData, ProjectDependency, TrainDataset } from './utils';
import { getGraphData, nodeFontSize, nodeHeight, nodeWidth, normalizeTreeData } from './utils';
import {
NodeType,
getGraphData,
nodeFontSize,
nodeHeight,
nodeWidth,
normalizeTreeData,
traverseHierarchically,
} from './utils';
type modeModelEvolutionProps = {
resourceId: number;
@ -37,6 +51,8 @@ function ModelEvolution({
const [hoverNodeData, setHoverNodeData] = useState<
ModelDepsData | ProjectDependency | TrainDataset | undefined
>(undefined);
const apiData = useRef<ModelDepsData | undefined>(undefined); // 接口返回的树形结构
const hierarchyNodes = useRef<ModelDepsData[]>([]); // 层级迭代树形结构,得到的节点列表
useEffect(() => {
initGraph();
@ -111,18 +127,7 @@ function ModelEvolution({
},
},
modes: {
default: [
'drag-canvas',
'zoom-canvas',
// {
// type: 'collapse-expand',
// onChange(item?: Item, collapsed?: boolean) {
// const data = item!.getModel();
// data.collapsed = collapsed;
// return true;
// },
// },
],
default: ['drag-canvas', 'zoom-canvas'],
},
});
@ -161,11 +166,26 @@ function ModelEvolution({
});
graph.on('node:click', (e: G6GraphEvent) => {
const nodeItem = e.item;
const nodeItem = e.item as INode;
const model = nodeItem.getModel() as ModelDepsData | ProjectDependency | TrainDataset;
const { model_type } = model;
switch (model_type) {
if (
model_type === NodeType.Project ||
model_type === NodeType.TrainDataset ||
model_type === NodeType.TestDataset ||
!apiData.current ||
!hierarchyNodes.current
) {
return;
}
setShowNodeTooltip(false);
setEnterTooltip(false);
toggleExpended(model.id);
const graphData = getGraphData(apiData.current, hierarchyNodes.current);
graph.data(graphData);
graph.render();
graph.fitView();
});
// 鼠标滚轮缩放时,隐藏 tooltip
@ -175,6 +195,17 @@ function ModelEvolution({
});
};
// toggle 展开
const toggleExpended = (id: string) => {
const nodes = hierarchyNodes.current;
for (const node of nodes) {
if (node.id === id) {
node.expanded = !node.expanded;
break;
}
}
};
const handleTooltipsMouseEnter = () => {
setEnterTooltip(true);
};
@ -192,7 +223,9 @@ function ModelEvolution({
const [res] = await to(getModelAtlasReq(params));
if (res && res.data) {
const data = normalizeTreeData(res.data);
const graphData = getGraphData(data);
apiData.current = data;
hierarchyNodes.current = traverseHierarchically(data);
const graphData = getGraphData(data, hierarchyNodes.current);
graph.data(graphData);
graph.render();

View File

@ -6,21 +6,22 @@ import Hierarchy from '@antv/hierarchy';
export const nodeWidth = 90;
export const nodeHeight = 40;
export const vGap = nodeHeight + 20;
export const hGap = nodeWidth;
export const hGap = nodeHeight + 20;
export const ellipseWidth = nodeWidth;
export const labelPadding = 30;
export const nodeFontSize = 8;
export const datasetHGap = 20;
// 数据集节点
const datasetNodes: NodeConfig[] = [];
export enum NodeType {
current = 'current',
parent = 'parent',
children = 'children',
project = 'project',
trainDataset = 'trainDataset',
testDataset = 'testDataset',
Current = 'Current', // 当前模型
Parent = 'Parent', // 父模型
Children = 'Children', // 子模型
Project = 'Project', // 项目
TrainDataset = 'TrainDataset', // 训练数据集
TestDataset = 'TestDataset', // 测试数据集
}
export type Rect = {
@ -40,14 +41,14 @@ export interface TrainDataset extends NodeConfig {
dataset_id: number;
dataset_name: string;
dataset_version: string;
model_type: NodeType.testDataset | NodeType.trainDataset;
model_type: NodeType.TestDataset | NodeType.TrainDataset;
}
export interface ProjectDependency extends NodeConfig {
url: string;
name: string;
branch: string;
model_type: NodeType.project;
model_type: NodeType.Project;
}
export type ModalDetail = {
@ -66,9 +67,9 @@ export interface ModelDepsAPIData {
version: string;
workflow_id: number;
exp_ins_id: number;
model_type: NodeType.children | NodeType.current | NodeType.parent;
model_type: NodeType.Children | NodeType.Current | NodeType.Parent;
current_model_name: string;
project_dependency: ProjectDependency;
project_dependency?: ProjectDependency;
test_dataset: TrainDataset[];
train_dataset: TrainDataset[];
train_task: TrainTask;
@ -79,16 +80,22 @@ export interface ModelDepsAPIData {
export interface ModelDepsData extends Omit<ModelDepsAPIData, 'children_models'>, TreeGraphData {
children: ModelDepsData[];
expanded: boolean; // 是否展开
level: number; // 层级,从 0 开始
datasetLen: number; // 数据集数量
}
// 规范化子数据
export function normalizeChildren(data: ModelDepsData[]) {
if (Array.isArray(data)) {
data.forEach((item) => {
item.model_type = NodeType.children;
item.model_type = NodeType.Children;
item.expanded = false;
item.level = 0;
item.datasetLen = item.train_dataset.length + item.test_dataset.length;
item.id = `$M_${item.current_model_id}_${item.version}`;
item.label = getLabel(item);
item.style = getStyle(NodeType.children);
item.style = getStyle(NodeType.Children);
normalizeChildren(item.children);
});
}
@ -111,22 +118,22 @@ export function getLabel(node: ModelDepsData | ModelDepsAPIData) {
export function getStyle(model_type: NodeType) {
let fill = '';
switch (model_type) {
case NodeType.current:
case NodeType.Current:
fill = 'l(0) 0:#72a1ff 1:#1664ff';
break;
case NodeType.parent:
case NodeType.Parent:
fill = 'l(0) 0:#93dfd1 1:#43c9b1';
break;
case NodeType.children:
case NodeType.Children:
fill = 'l(0) 0:#72b4ff 1:#169aff';
break;
case NodeType.project:
case NodeType.Project:
fill = 'l(0) 0:#b3a9ff 1:#8981ff';
break;
case NodeType.trainDataset:
case NodeType.TrainDataset:
fill = '#a5d878';
break;
case NodeType.testDataset:
case NodeType.TestDataset:
fill = '#d8b578';
break;
default:
@ -145,11 +152,15 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
}) as ModelDepsData;
// 设置当前模型的数据
normalizedData.model_type = NodeType.current;
normalizedData.model_type = NodeType.Current;
normalizedData.id = `$M_${normalizedData.current_model_id}_${normalizedData.version}`;
normalizedData.label = getLabel(normalizedData);
normalizedData.style = getStyle(NodeType.current);
normalizedData.style = getStyle(NodeType.Current);
normalizedData.expanded = true;
normalizedData.datasetLen =
normalizedData.train_dataset.length + normalizedData.test_dataset.length;
normalizeChildren(normalizedData.children as ModelDepsData[]);
normalizedData.level = 0;
// 将 parent_models 转换成树形结构
let parent_models = normalizedData.parent_models || [];
@ -157,10 +168,13 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
const parent = parent_models[0];
normalizedData = {
...parent,
model_type: NodeType.parent,
expanded: false,
level: 0,
datasetLen: parent.train_dataset.length + parent.test_dataset.length,
model_type: NodeType.Parent,
id: `$M_${parent.current_model_id}_${parent.version}`,
label: getLabel(parent),
style: getStyle(NodeType.parent),
style: getStyle(NodeType.Parent),
children: [
{
...normalizedData,
@ -174,13 +188,34 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
}
// 将树形数据,使用 Hierarchy 进行布局,计算出坐标,然后转换成 G6 的数据
export function getGraphData(data: ModelDepsData): GraphData {
export function getGraphData(data: ModelDepsData, hierarchyNodes: ModelDepsData[]): GraphData {
const config = {
direction: 'LR',
getHeight: () => nodeHeight,
getWidth: () => nodeWidth,
getVGap: () => vGap / 2,
getHGap: () => hGap / 2,
getVGap: (node: NodeConfig) => {
const model = node as ModelDepsData;
const { model_type, expanded, project_dependency } = model;
if (model_type === NodeType.Current || model_type === NodeType.Parent) {
return vGap / 2;
}
const selfGap = expanded && project_dependency?.url ? nodeHeight + vGap : 0;
const nextNode = getSameHierarchyNextNode(model, hierarchyNodes);
if (!nextNode) {
return vGap / 2;
}
const nextGap = nextNode.expanded === true && nextNode.datasetLen > 0 ? nodeHeight + vGap : 0;
return (selfGap + nextGap + vGap) / 2;
},
getHGap: (node: NodeConfig) => {
const model = node as ModelDepsData;
return (
(getHierarchyWidth(model.level, hierarchyNodes) +
getHierarchyWidth(model.level + 1, hierarchyNodes) +
hGap) /
2
);
},
};
// 树形布局计算出坐标
@ -191,11 +226,11 @@ export function getGraphData(data: ModelDepsData): GraphData {
Util.traverseTree(treeLayoutData, (node: NodeConfig, parent: NodeConfig) => {
const data = node.data as ModelDepsData;
// 当前模型显示数据集和项目
if (data.model_type === NodeType.current) {
if (data.expanded === true) {
addDatasetDependency(data, node, nodes, edges);
addProjectDependency(data, node, nodes, edges);
} else if (data.model_type === NodeType.children) {
adjustDatasetPosition(node);
} else if (data.model_type === NodeType.Children) {
// adjustDatasetPosition(node);
}
nodes.push({
...data,
@ -219,16 +254,16 @@ const addDatasetDependency = (
nodes: NodeConfig[],
edges: EdgeConfig[],
) => {
const { train_dataset, test_dataset } = data;
const { train_dataset, test_dataset, id } = data;
train_dataset.forEach((item) => {
item.id = `$DTrain_${item.dataset_id}_${item.dataset_version}`;
item.model_type = NodeType.trainDataset;
item.style = getStyle(NodeType.trainDataset);
item.id = `$DTrain_${id}_${item.dataset_id}_${item.dataset_version}`;
item.model_type = NodeType.TrainDataset;
item.style = getStyle(NodeType.TrainDataset);
});
test_dataset.forEach((item) => {
item.id = `$DTest_${item.dataset_id}_${item.dataset_version}`;
item.model_type = NodeType.testDataset;
item.style = getStyle(NodeType.testDataset);
item.id = `$DTest_${id}_${item.dataset_id}_${item.dataset_version}`;
item.model_type = NodeType.TestDataset;
item.style = getStyle(NodeType.TestDataset);
});
datasetNodes.length = 0;
@ -243,7 +278,7 @@ const addDatasetDependency = (
fittingString(node.dataset_version, ellipseWidth - labelPadding, nodeFontSize);
const half = len / 2 - 0.5;
node.x = currentNode.x! - (half - index) * (ellipseWidth + 20);
node.x = currentNode.x! - (half - index) * (ellipseWidth + datasetHGap);
node.y = currentNode.y! - nodeHeight - vGap;
nodes.push(node);
datasetNodes.push(node);
@ -264,14 +299,14 @@ const addProjectDependency = (
nodes: NodeConfig[],
edges: EdgeConfig[],
) => {
const { project_dependency } = data;
const { project_dependency, id } = data;
if (project_dependency?.url) {
const node = { ...project_dependency };
node.id = `$P_${node.url}_${node.branch}`;
node.model_type = NodeType.project;
node.id = `$P_${id}_${node.url}_${node.branch}`;
node.model_type = NodeType.Project;
node.type = 'rect';
node.label = fittingString(node.name, nodeWidth - labelPadding, nodeFontSize);
node.style = getStyle(NodeType.project);
node.style = getStyle(NodeType.Project);
node.style.radius = nodeHeight / 2;
node.x = currentNode.x;
node.y = currentNode.y! + nodeHeight + vGap;
@ -331,3 +366,49 @@ function adjustDatasetPosition(node: NodeConfig) {
});
}
}
// 层级遍历树结构
export function traverseHierarchically(data: ModelDepsData | undefined): ModelDepsData[] {
if (!data) return [];
let level = 0;
data.level = level;
const result: ModelDepsData[] = [data];
let index = 0;
while (index < result.length) {
const item = result[index];
if (item.children) {
item.children.forEach((child) => {
child.level = item.level + 1;
result.push(child);
});
}
index++;
}
return result;
}
// 找到同层次的下一个节点
export function getSameHierarchyNextNode(node: ModelDepsData, nodes: ModelDepsData[]) {
const index = nodes.findIndex((item) => item.id === node.id);
if (index >= 0 && index < nodes.length - 1) {
const nextNode = nodes[index + 1];
if (nextNode.level === node.level) {
return nextNode;
}
}
return null;
}
// 得到层级的宽度
export function getHierarchyWidth(level: number, nodes: ModelDepsData[]) {
const hierarchyNodes = nodes
.filter((item) => item.level === level && item.expanded === true)
.sort((a, b) => b.datasetLen - a.datasetLen);
const first = hierarchyNodes[0];
if (first) {
return Math.max(((first.datasetLen - 1) * (nodeWidth + datasetHGap)) / 2, 0);
}
return 0;
}

View File

@ -22,7 +22,7 @@ function ModelInfo({ resourceId, data, onVersionChange }: ModelInfoProps) {
};
const gotoModelPage = () => {
if (data.model_type === NodeType.current) {
if (data.model_type === NodeType.Current) {
return;
}
if (data.current_model_id === resourceId) {
@ -39,7 +39,7 @@ function ModelInfo({ resourceId, data, onVersionChange }: ModelInfoProps) {
<div>
<div className={styles['node-tooltips__row']}>
<span className={styles['node-tooltips__row__title']}></span>
{data.model_type === NodeType.current ? (
{data.model_type === NodeType.Current ? (
<span className={styles['node-tooltips__row__value']}>
{data.model_version_dependcy_vo?.name || '--'}
</span>
@ -199,14 +199,14 @@ function NodeTooltips({
if (!data) return null;
let Component = null;
const { model_type } = data;
if (model_type === NodeType.testDataset || model_type === NodeType.trainDataset) {
if (model_type === NodeType.TestDataset || model_type === NodeType.TrainDataset) {
Component = <DatasetInfo data={data} />;
} else if (model_type === NodeType.project) {
} else if (model_type === NodeType.Project) {
Component = <ProjectInfo data={data} />;
} else if (
model_type === NodeType.children ||
model_type === NodeType.parent ||
model_type === NodeType.current
model_type === NodeType.Children ||
model_type === NodeType.Parent ||
model_type === NodeType.Current
) {
Component = <ModelInfo resourceId={resourceId} data={data} onVersionChange={onVersionChange} />;
}

View File

@ -223,7 +223,7 @@ function ModelDeployment() {
{
title: '操作',
dataIndex: 'operation',
width: 350,
width: 250,
key: 'operation',
render: (_: any, record: ModelDeploymentData) => (
<div>