feat: 模型演化添加点击节点展开数据集和项目
This commit is contained in:
parent
e88cc9e197
commit
2cc70b1f00
|
@ -1,9 +1,15 @@
|
|||
/*
|
||||
* @Author: 赵伟
|
||||
* @Date: 2024-06-07 11:24:10
|
||||
* @Description: 模型演化
|
||||
*/
|
||||
|
||||
import { useEffectWhen } from '@/hooks';
|
||||
import { ResourceVersionData } from '@/pages/Dataset/config';
|
||||
import { getModelAtlasReq } from '@/services/dataset/index.js';
|
||||
import themes from '@/styles/theme.less';
|
||||
import { to } from '@/utils/promise';
|
||||
import G6, { G6GraphEvent, Graph } from '@antv/g6';
|
||||
import G6, { G6GraphEvent, Graph, INode } from '@antv/g6';
|
||||
// @ts-ignore
|
||||
import { Flex, Select } from 'antd';
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
@ -11,7 +17,15 @@ import GraphLegend from '../GraphLegend';
|
|||
import NodeTooltips from '../NodeTooltips';
|
||||
import styles from './index.less';
|
||||
import type { ModelDepsData, ProjectDependency, TrainDataset } from './utils';
|
||||
import { getGraphData, nodeFontSize, nodeHeight, nodeWidth, normalizeTreeData } from './utils';
|
||||
import {
|
||||
NodeType,
|
||||
getGraphData,
|
||||
nodeFontSize,
|
||||
nodeHeight,
|
||||
nodeWidth,
|
||||
normalizeTreeData,
|
||||
traverseHierarchically,
|
||||
} from './utils';
|
||||
|
||||
type modeModelEvolutionProps = {
|
||||
resourceId: number;
|
||||
|
@ -37,6 +51,8 @@ function ModelEvolution({
|
|||
const [hoverNodeData, setHoverNodeData] = useState<
|
||||
ModelDepsData | ProjectDependency | TrainDataset | undefined
|
||||
>(undefined);
|
||||
const apiData = useRef<ModelDepsData | undefined>(undefined); // 接口返回的树形结构
|
||||
const hierarchyNodes = useRef<ModelDepsData[]>([]); // 层级迭代树形结构,得到的节点列表
|
||||
|
||||
useEffect(() => {
|
||||
initGraph();
|
||||
|
@ -111,18 +127,7 @@ function ModelEvolution({
|
|||
},
|
||||
},
|
||||
modes: {
|
||||
default: [
|
||||
'drag-canvas',
|
||||
'zoom-canvas',
|
||||
// {
|
||||
// type: 'collapse-expand',
|
||||
// onChange(item?: Item, collapsed?: boolean) {
|
||||
// const data = item!.getModel();
|
||||
// data.collapsed = collapsed;
|
||||
// return true;
|
||||
// },
|
||||
// },
|
||||
],
|
||||
default: ['drag-canvas', 'zoom-canvas'],
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -161,11 +166,26 @@ function ModelEvolution({
|
|||
});
|
||||
|
||||
graph.on('node:click', (e: G6GraphEvent) => {
|
||||
const nodeItem = e.item;
|
||||
const nodeItem = e.item as INode;
|
||||
const model = nodeItem.getModel() as ModelDepsData | ProjectDependency | TrainDataset;
|
||||
const { model_type } = model;
|
||||
switch (model_type) {
|
||||
if (
|
||||
model_type === NodeType.Project ||
|
||||
model_type === NodeType.TrainDataset ||
|
||||
model_type === NodeType.TestDataset ||
|
||||
!apiData.current ||
|
||||
!hierarchyNodes.current
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
setShowNodeTooltip(false);
|
||||
setEnterTooltip(false);
|
||||
toggleExpended(model.id);
|
||||
const graphData = getGraphData(apiData.current, hierarchyNodes.current);
|
||||
graph.data(graphData);
|
||||
graph.render();
|
||||
graph.fitView();
|
||||
});
|
||||
|
||||
// 鼠标滚轮缩放时,隐藏 tooltip
|
||||
|
@ -175,6 +195,17 @@ function ModelEvolution({
|
|||
});
|
||||
};
|
||||
|
||||
// toggle 展开
|
||||
const toggleExpended = (id: string) => {
|
||||
const nodes = hierarchyNodes.current;
|
||||
for (const node of nodes) {
|
||||
if (node.id === id) {
|
||||
node.expanded = !node.expanded;
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleTooltipsMouseEnter = () => {
|
||||
setEnterTooltip(true);
|
||||
};
|
||||
|
@ -192,7 +223,9 @@ function ModelEvolution({
|
|||
const [res] = await to(getModelAtlasReq(params));
|
||||
if (res && res.data) {
|
||||
const data = normalizeTreeData(res.data);
|
||||
const graphData = getGraphData(data);
|
||||
apiData.current = data;
|
||||
hierarchyNodes.current = traverseHierarchically(data);
|
||||
const graphData = getGraphData(data, hierarchyNodes.current);
|
||||
|
||||
graph.data(graphData);
|
||||
graph.render();
|
||||
|
|
|
@ -6,21 +6,22 @@ import Hierarchy from '@antv/hierarchy';
|
|||
export const nodeWidth = 90;
|
||||
export const nodeHeight = 40;
|
||||
export const vGap = nodeHeight + 20;
|
||||
export const hGap = nodeWidth;
|
||||
export const hGap = nodeHeight + 20;
|
||||
export const ellipseWidth = nodeWidth;
|
||||
export const labelPadding = 30;
|
||||
export const nodeFontSize = 8;
|
||||
export const datasetHGap = 20;
|
||||
|
||||
// 数据集节点
|
||||
const datasetNodes: NodeConfig[] = [];
|
||||
|
||||
export enum NodeType {
|
||||
current = 'current',
|
||||
parent = 'parent',
|
||||
children = 'children',
|
||||
project = 'project',
|
||||
trainDataset = 'trainDataset',
|
||||
testDataset = 'testDataset',
|
||||
Current = 'Current', // 当前模型
|
||||
Parent = 'Parent', // 父模型
|
||||
Children = 'Children', // 子模型
|
||||
Project = 'Project', // 项目
|
||||
TrainDataset = 'TrainDataset', // 训练数据集
|
||||
TestDataset = 'TestDataset', // 测试数据集
|
||||
}
|
||||
|
||||
export type Rect = {
|
||||
|
@ -40,14 +41,14 @@ export interface TrainDataset extends NodeConfig {
|
|||
dataset_id: number;
|
||||
dataset_name: string;
|
||||
dataset_version: string;
|
||||
model_type: NodeType.testDataset | NodeType.trainDataset;
|
||||
model_type: NodeType.TestDataset | NodeType.TrainDataset;
|
||||
}
|
||||
|
||||
export interface ProjectDependency extends NodeConfig {
|
||||
url: string;
|
||||
name: string;
|
||||
branch: string;
|
||||
model_type: NodeType.project;
|
||||
model_type: NodeType.Project;
|
||||
}
|
||||
|
||||
export type ModalDetail = {
|
||||
|
@ -66,9 +67,9 @@ export interface ModelDepsAPIData {
|
|||
version: string;
|
||||
workflow_id: number;
|
||||
exp_ins_id: number;
|
||||
model_type: NodeType.children | NodeType.current | NodeType.parent;
|
||||
model_type: NodeType.Children | NodeType.Current | NodeType.Parent;
|
||||
current_model_name: string;
|
||||
project_dependency: ProjectDependency;
|
||||
project_dependency?: ProjectDependency;
|
||||
test_dataset: TrainDataset[];
|
||||
train_dataset: TrainDataset[];
|
||||
train_task: TrainTask;
|
||||
|
@ -79,16 +80,22 @@ export interface ModelDepsAPIData {
|
|||
|
||||
export interface ModelDepsData extends Omit<ModelDepsAPIData, 'children_models'>, TreeGraphData {
|
||||
children: ModelDepsData[];
|
||||
expanded: boolean; // 是否展开
|
||||
level: number; // 层级,从 0 开始
|
||||
datasetLen: number; // 数据集数量
|
||||
}
|
||||
|
||||
// 规范化子数据
|
||||
export function normalizeChildren(data: ModelDepsData[]) {
|
||||
if (Array.isArray(data)) {
|
||||
data.forEach((item) => {
|
||||
item.model_type = NodeType.children;
|
||||
item.model_type = NodeType.Children;
|
||||
item.expanded = false;
|
||||
item.level = 0;
|
||||
item.datasetLen = item.train_dataset.length + item.test_dataset.length;
|
||||
item.id = `$M_${item.current_model_id}_${item.version}`;
|
||||
item.label = getLabel(item);
|
||||
item.style = getStyle(NodeType.children);
|
||||
item.style = getStyle(NodeType.Children);
|
||||
normalizeChildren(item.children);
|
||||
});
|
||||
}
|
||||
|
@ -111,22 +118,22 @@ export function getLabel(node: ModelDepsData | ModelDepsAPIData) {
|
|||
export function getStyle(model_type: NodeType) {
|
||||
let fill = '';
|
||||
switch (model_type) {
|
||||
case NodeType.current:
|
||||
case NodeType.Current:
|
||||
fill = 'l(0) 0:#72a1ff 1:#1664ff';
|
||||
break;
|
||||
case NodeType.parent:
|
||||
case NodeType.Parent:
|
||||
fill = 'l(0) 0:#93dfd1 1:#43c9b1';
|
||||
break;
|
||||
case NodeType.children:
|
||||
case NodeType.Children:
|
||||
fill = 'l(0) 0:#72b4ff 1:#169aff';
|
||||
break;
|
||||
case NodeType.project:
|
||||
case NodeType.Project:
|
||||
fill = 'l(0) 0:#b3a9ff 1:#8981ff';
|
||||
break;
|
||||
case NodeType.trainDataset:
|
||||
case NodeType.TrainDataset:
|
||||
fill = '#a5d878';
|
||||
break;
|
||||
case NodeType.testDataset:
|
||||
case NodeType.TestDataset:
|
||||
fill = '#d8b578';
|
||||
break;
|
||||
default:
|
||||
|
@ -145,11 +152,15 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
|
|||
}) as ModelDepsData;
|
||||
|
||||
// 设置当前模型的数据
|
||||
normalizedData.model_type = NodeType.current;
|
||||
normalizedData.model_type = NodeType.Current;
|
||||
normalizedData.id = `$M_${normalizedData.current_model_id}_${normalizedData.version}`;
|
||||
normalizedData.label = getLabel(normalizedData);
|
||||
normalizedData.style = getStyle(NodeType.current);
|
||||
normalizedData.style = getStyle(NodeType.Current);
|
||||
normalizedData.expanded = true;
|
||||
normalizedData.datasetLen =
|
||||
normalizedData.train_dataset.length + normalizedData.test_dataset.length;
|
||||
normalizeChildren(normalizedData.children as ModelDepsData[]);
|
||||
normalizedData.level = 0;
|
||||
|
||||
// 将 parent_models 转换成树形结构
|
||||
let parent_models = normalizedData.parent_models || [];
|
||||
|
@ -157,10 +168,13 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
|
|||
const parent = parent_models[0];
|
||||
normalizedData = {
|
||||
...parent,
|
||||
model_type: NodeType.parent,
|
||||
expanded: false,
|
||||
level: 0,
|
||||
datasetLen: parent.train_dataset.length + parent.test_dataset.length,
|
||||
model_type: NodeType.Parent,
|
||||
id: `$M_${parent.current_model_id}_${parent.version}`,
|
||||
label: getLabel(parent),
|
||||
style: getStyle(NodeType.parent),
|
||||
style: getStyle(NodeType.Parent),
|
||||
children: [
|
||||
{
|
||||
...normalizedData,
|
||||
|
@ -174,13 +188,34 @@ export function normalizeTreeData(apiData: ModelDepsAPIData): ModelDepsData {
|
|||
}
|
||||
|
||||
// 将树形数据,使用 Hierarchy 进行布局,计算出坐标,然后转换成 G6 的数据
|
||||
export function getGraphData(data: ModelDepsData): GraphData {
|
||||
export function getGraphData(data: ModelDepsData, hierarchyNodes: ModelDepsData[]): GraphData {
|
||||
const config = {
|
||||
direction: 'LR',
|
||||
getHeight: () => nodeHeight,
|
||||
getWidth: () => nodeWidth,
|
||||
getVGap: () => vGap / 2,
|
||||
getHGap: () => hGap / 2,
|
||||
getVGap: (node: NodeConfig) => {
|
||||
const model = node as ModelDepsData;
|
||||
const { model_type, expanded, project_dependency } = model;
|
||||
if (model_type === NodeType.Current || model_type === NodeType.Parent) {
|
||||
return vGap / 2;
|
||||
}
|
||||
const selfGap = expanded && project_dependency?.url ? nodeHeight + vGap : 0;
|
||||
const nextNode = getSameHierarchyNextNode(model, hierarchyNodes);
|
||||
if (!nextNode) {
|
||||
return vGap / 2;
|
||||
}
|
||||
const nextGap = nextNode.expanded === true && nextNode.datasetLen > 0 ? nodeHeight + vGap : 0;
|
||||
return (selfGap + nextGap + vGap) / 2;
|
||||
},
|
||||
getHGap: (node: NodeConfig) => {
|
||||
const model = node as ModelDepsData;
|
||||
return (
|
||||
(getHierarchyWidth(model.level, hierarchyNodes) +
|
||||
getHierarchyWidth(model.level + 1, hierarchyNodes) +
|
||||
hGap) /
|
||||
2
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
// 树形布局计算出坐标
|
||||
|
@ -191,11 +226,11 @@ export function getGraphData(data: ModelDepsData): GraphData {
|
|||
Util.traverseTree(treeLayoutData, (node: NodeConfig, parent: NodeConfig) => {
|
||||
const data = node.data as ModelDepsData;
|
||||
// 当前模型显示数据集和项目
|
||||
if (data.model_type === NodeType.current) {
|
||||
if (data.expanded === true) {
|
||||
addDatasetDependency(data, node, nodes, edges);
|
||||
addProjectDependency(data, node, nodes, edges);
|
||||
} else if (data.model_type === NodeType.children) {
|
||||
adjustDatasetPosition(node);
|
||||
} else if (data.model_type === NodeType.Children) {
|
||||
// adjustDatasetPosition(node);
|
||||
}
|
||||
nodes.push({
|
||||
...data,
|
||||
|
@ -219,16 +254,16 @@ const addDatasetDependency = (
|
|||
nodes: NodeConfig[],
|
||||
edges: EdgeConfig[],
|
||||
) => {
|
||||
const { train_dataset, test_dataset } = data;
|
||||
const { train_dataset, test_dataset, id } = data;
|
||||
train_dataset.forEach((item) => {
|
||||
item.id = `$DTrain_${item.dataset_id}_${item.dataset_version}`;
|
||||
item.model_type = NodeType.trainDataset;
|
||||
item.style = getStyle(NodeType.trainDataset);
|
||||
item.id = `$DTrain_${id}_${item.dataset_id}_${item.dataset_version}`;
|
||||
item.model_type = NodeType.TrainDataset;
|
||||
item.style = getStyle(NodeType.TrainDataset);
|
||||
});
|
||||
test_dataset.forEach((item) => {
|
||||
item.id = `$DTest_${item.dataset_id}_${item.dataset_version}`;
|
||||
item.model_type = NodeType.testDataset;
|
||||
item.style = getStyle(NodeType.testDataset);
|
||||
item.id = `$DTest_${id}_${item.dataset_id}_${item.dataset_version}`;
|
||||
item.model_type = NodeType.TestDataset;
|
||||
item.style = getStyle(NodeType.TestDataset);
|
||||
});
|
||||
|
||||
datasetNodes.length = 0;
|
||||
|
@ -243,7 +278,7 @@ const addDatasetDependency = (
|
|||
fittingString(node.dataset_version, ellipseWidth - labelPadding, nodeFontSize);
|
||||
|
||||
const half = len / 2 - 0.5;
|
||||
node.x = currentNode.x! - (half - index) * (ellipseWidth + 20);
|
||||
node.x = currentNode.x! - (half - index) * (ellipseWidth + datasetHGap);
|
||||
node.y = currentNode.y! - nodeHeight - vGap;
|
||||
nodes.push(node);
|
||||
datasetNodes.push(node);
|
||||
|
@ -264,14 +299,14 @@ const addProjectDependency = (
|
|||
nodes: NodeConfig[],
|
||||
edges: EdgeConfig[],
|
||||
) => {
|
||||
const { project_dependency } = data;
|
||||
const { project_dependency, id } = data;
|
||||
if (project_dependency?.url) {
|
||||
const node = { ...project_dependency };
|
||||
node.id = `$P_${node.url}_${node.branch}`;
|
||||
node.model_type = NodeType.project;
|
||||
node.id = `$P_${id}_${node.url}_${node.branch}`;
|
||||
node.model_type = NodeType.Project;
|
||||
node.type = 'rect';
|
||||
node.label = fittingString(node.name, nodeWidth - labelPadding, nodeFontSize);
|
||||
node.style = getStyle(NodeType.project);
|
||||
node.style = getStyle(NodeType.Project);
|
||||
node.style.radius = nodeHeight / 2;
|
||||
node.x = currentNode.x;
|
||||
node.y = currentNode.y! + nodeHeight + vGap;
|
||||
|
@ -331,3 +366,49 @@ function adjustDatasetPosition(node: NodeConfig) {
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 层级遍历树结构
|
||||
export function traverseHierarchically(data: ModelDepsData | undefined): ModelDepsData[] {
|
||||
if (!data) return [];
|
||||
let level = 0;
|
||||
data.level = level;
|
||||
const result: ModelDepsData[] = [data];
|
||||
let index = 0;
|
||||
|
||||
while (index < result.length) {
|
||||
const item = result[index];
|
||||
if (item.children) {
|
||||
item.children.forEach((child) => {
|
||||
child.level = item.level + 1;
|
||||
result.push(child);
|
||||
});
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// 找到同层次的下一个节点
|
||||
export function getSameHierarchyNextNode(node: ModelDepsData, nodes: ModelDepsData[]) {
|
||||
const index = nodes.findIndex((item) => item.id === node.id);
|
||||
if (index >= 0 && index < nodes.length - 1) {
|
||||
const nextNode = nodes[index + 1];
|
||||
if (nextNode.level === node.level) {
|
||||
return nextNode;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// 得到层级的宽度
|
||||
export function getHierarchyWidth(level: number, nodes: ModelDepsData[]) {
|
||||
const hierarchyNodes = nodes
|
||||
.filter((item) => item.level === level && item.expanded === true)
|
||||
.sort((a, b) => b.datasetLen - a.datasetLen);
|
||||
const first = hierarchyNodes[0];
|
||||
if (first) {
|
||||
return Math.max(((first.datasetLen - 1) * (nodeWidth + datasetHGap)) / 2, 0);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ function ModelInfo({ resourceId, data, onVersionChange }: ModelInfoProps) {
|
|||
};
|
||||
|
||||
const gotoModelPage = () => {
|
||||
if (data.model_type === NodeType.current) {
|
||||
if (data.model_type === NodeType.Current) {
|
||||
return;
|
||||
}
|
||||
if (data.current_model_id === resourceId) {
|
||||
|
@ -39,7 +39,7 @@ function ModelInfo({ resourceId, data, onVersionChange }: ModelInfoProps) {
|
|||
<div>
|
||||
<div className={styles['node-tooltips__row']}>
|
||||
<span className={styles['node-tooltips__row__title']}>模型名称:</span>
|
||||
{data.model_type === NodeType.current ? (
|
||||
{data.model_type === NodeType.Current ? (
|
||||
<span className={styles['node-tooltips__row__value']}>
|
||||
{data.model_version_dependcy_vo?.name || '--'}
|
||||
</span>
|
||||
|
@ -199,14 +199,14 @@ function NodeTooltips({
|
|||
if (!data) return null;
|
||||
let Component = null;
|
||||
const { model_type } = data;
|
||||
if (model_type === NodeType.testDataset || model_type === NodeType.trainDataset) {
|
||||
if (model_type === NodeType.TestDataset || model_type === NodeType.TrainDataset) {
|
||||
Component = <DatasetInfo data={data} />;
|
||||
} else if (model_type === NodeType.project) {
|
||||
} else if (model_type === NodeType.Project) {
|
||||
Component = <ProjectInfo data={data} />;
|
||||
} else if (
|
||||
model_type === NodeType.children ||
|
||||
model_type === NodeType.parent ||
|
||||
model_type === NodeType.current
|
||||
model_type === NodeType.Children ||
|
||||
model_type === NodeType.Parent ||
|
||||
model_type === NodeType.Current
|
||||
) {
|
||||
Component = <ModelInfo resourceId={resourceId} data={data} onVersionChange={onVersionChange} />;
|
||||
}
|
||||
|
|
|
@ -223,7 +223,7 @@ function ModelDeployment() {
|
|||
{
|
||||
title: '操作',
|
||||
dataIndex: 'operation',
|
||||
width: 350,
|
||||
width: 250,
|
||||
key: 'operation',
|
||||
render: (_: any, record: ModelDeploymentData) => (
|
||||
<div>
|
||||
|
|
Loading…
Reference in New Issue