Add new model in doc table of content (#25148)
This commit is contained in:
parent
e93103632b
commit
400e76ef11
|
@ -23,6 +23,8 @@ from itertools import chain
|
|||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from ..models import auto as auto_module
|
||||
from ..models.auto.configuration_auto import model_type_to_module_name
|
||||
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
|
||||
|
@ -1268,6 +1270,56 @@ def duplicate_doc_file(
|
|||
f.write("\n".join(new_blocks))
|
||||
|
||||
|
||||
def insert_model_in_doc_toc(old_model_patterns, new_model_patterns):
|
||||
"""
|
||||
Insert the new model in the doc TOC, in the same section as the old model.
|
||||
|
||||
Args:
|
||||
old_model_patterns (`ModelPatterns`): The patterns for the old model.
|
||||
new_model_patterns (`ModelPatterns`): The patterns for the new model.
|
||||
"""
|
||||
toc_file = REPO_PATH / "docs" / "source" / "en" / "_toctree.yml"
|
||||
with open(toc_file, "r", encoding="utf8") as f:
|
||||
content = yaml.safe_load(f)
|
||||
|
||||
# Get to the model API doc
|
||||
api_idx = 0
|
||||
while content[api_idx]["title"] != "API":
|
||||
api_idx += 1
|
||||
api_doc = content[api_idx]["sections"]
|
||||
|
||||
model_idx = 0
|
||||
while api_doc[model_idx]["title"] != "Models":
|
||||
model_idx += 1
|
||||
model_doc = api_doc[model_idx]["sections"]
|
||||
|
||||
# Find the base model in the Toc
|
||||
old_model_type = old_model_patterns.model_type
|
||||
section_idx = 0
|
||||
while section_idx < len(model_doc):
|
||||
sections = [entry["local"] for entry in model_doc[section_idx]["sections"]]
|
||||
if f"model_doc/{old_model_type}" in sections:
|
||||
break
|
||||
|
||||
section_idx += 1
|
||||
|
||||
if section_idx == len(model_doc):
|
||||
old_model = old_model_patterns.model_name
|
||||
new_model = new_model_patterns.model_name
|
||||
print(f"Did not find {old_model} in the table of content, so you will need to add {new_model} manually.")
|
||||
return
|
||||
|
||||
# Add the new model in the same toc
|
||||
toc_entry = {"local": f"model_doc/{new_model_patterns.model_type}", "title": new_model_patterns.model_name}
|
||||
model_doc[section_idx]["sections"].append(toc_entry)
|
||||
model_doc[section_idx]["sections"] = sorted(model_doc[section_idx]["sections"], key=lambda s: s["title"].lower())
|
||||
api_doc[model_idx]["sections"] = model_doc
|
||||
content[api_idx]["sections"] = api_doc
|
||||
|
||||
with open(toc_file, "w", encoding="utf-8") as f:
|
||||
f.write(yaml.dump(content, allow_unicode=True))
|
||||
|
||||
|
||||
def create_new_model_like(
|
||||
model_type: str,
|
||||
new_model_patterns: ModelPatterns,
|
||||
|
@ -1407,6 +1459,7 @@ def create_new_model_like(
|
|||
# 5. Add doc file
|
||||
doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{old_model_patterns.model_type}.md"
|
||||
duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks)
|
||||
insert_model_in_doc_toc(old_model_patterns, new_model_patterns)
|
||||
|
||||
# 6. Warn the user for duplicate patterns
|
||||
if old_model_patterns.model_type == old_model_patterns.checkpoint:
|
||||
|
|
Loading…
Reference in New Issue