Add new model in doc table of content (#25148)

This commit is contained in:
Sylvain Gugger 2023-07-27 13:41:50 -04:00 committed by GitHub
parent e93103632b
commit 400e76ef11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 53 additions and 0 deletions

View File

@ -23,6 +23,8 @@ from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union
import yaml
from ..models import auto as auto_module
from ..models.auto.configuration_auto import model_type_to_module_name
from ..utils import is_flax_available, is_tf_available, is_torch_available, logging
@ -1268,6 +1270,56 @@ def duplicate_doc_file(
f.write("\n".join(new_blocks))
def insert_model_in_doc_toc(old_model_patterns, new_model_patterns):
"""
Insert the new model in the doc TOC, in the same section as the old model.
Args:
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
"""
toc_file = REPO_PATH / "docs" / "source" / "en" / "_toctree.yml"
with open(toc_file, "r", encoding="utf8") as f:
content = yaml.safe_load(f)
# Get to the model API doc
api_idx = 0
while content[api_idx]["title"] != "API":
api_idx += 1
api_doc = content[api_idx]["sections"]
model_idx = 0
while api_doc[model_idx]["title"] != "Models":
model_idx += 1
model_doc = api_doc[model_idx]["sections"]
# Find the base model in the Toc
old_model_type = old_model_patterns.model_type
section_idx = 0
while section_idx < len(model_doc):
sections = [entry["local"] for entry in model_doc[section_idx]["sections"]]
if f"model_doc/{old_model_type}" in sections:
break
section_idx += 1
if section_idx == len(model_doc):
old_model = old_model_patterns.model_name
new_model = new_model_patterns.model_name
print(f"Did not find {old_model} in the table of content, so you will need to add {new_model} manually.")
return
# Add the new model in the same toc
toc_entry = {"local": f"model_doc/{new_model_patterns.model_type}", "title": new_model_patterns.model_name}
model_doc[section_idx]["sections"].append(toc_entry)
model_doc[section_idx]["sections"] = sorted(model_doc[section_idx]["sections"], key=lambda s: s["title"].lower())
api_doc[model_idx]["sections"] = model_doc
content[api_idx]["sections"] = api_doc
with open(toc_file, "w", encoding="utf-8") as f:
f.write(yaml.dump(content, allow_unicode=True))
def create_new_model_like(
model_type: str,
new_model_patterns: ModelPatterns,
@ -1407,6 +1459,7 @@ def create_new_model_like(
# 5. Add doc file
doc_file = REPO_PATH / "docs" / "source" / "en" / "model_doc" / f"{old_model_patterns.model_type}.md"
duplicate_doc_file(doc_file, old_model_patterns, new_model_patterns, frameworks=frameworks)
insert_model_in_doc_toc(old_model_patterns, new_model_patterns)
# 6. Warn the user for duplicate patterns
if old_model_patterns.model_type == old_model_patterns.checkpoint: