Python 脚本:从 YAML 条目中提取公共元素到可派生的基元素

这个脚本是为优化 kicad-footprint-generator 的 YAML 定义而开发的。

它会扫描一个 footprint YAML 文件,找出在大多数条目(默认至少 90%,可通过 --threshold 调整)中出现且取值完全相同的键,并将它们移入一个共享的基元素中。每个单独的 footprint 通过 YAML 合并键 `<<: *base` 引用该基元素,只保留与基元素实际不同的参数,从而缩短文件长度并减少重复。

例如,给定这样一个最小输入:

original.yaml
Vertical:
    pad_width: 1.0
    pad_height: 2.0
    solder_mask_margin: 0.05

Horizontal:
    pad_width: 1.0
    pad_height: 2.0
    solder_mask_margin: 0.05
    special_option: true

该脚本会将公共值提取到基元素中,并将文件大致重写为如下形式(实际输出还会在基元素前加一行说明注释,并使用 2 空格缩进):

processed.yaml
base: &base
    pad_width: 1.0
    pad_height: 2.0
    solder_mask_margin: 0.05

Vertical:
    <<: *base

Horizontal:
    <<: *base
    special_option: true

对于大得多的 KiCad 生成器 YAML 文件,它做的也基本是同样的事情。注意:默认情况下脚本只打印提取出的基元素;要写回原文件需加 --overwrite,或用 -o 指定另一个输出路径。

yaml_extract_common_base.py
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2025 Uli Köhler <gitlab@techoverflow.net>
# SPDX-License-Identifier: CC0-1.0
"""从 KiCad 生成器 YAML 文件中提取并合并公共值。"""

from __future__ import annotations

import argparse
import copy
import importlib
import io
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

# ruamel.yaml is required: the script depends on its round-trip loader/dumper
# and comment-aware container types. A missing dependency is reported with an
# install hint instead of a bare ImportError traceback.
try:
    _ruamel_yaml = importlib.import_module("ruamel.yaml")
    _ruamel_comments = importlib.import_module("ruamel.yaml.comments")
except ImportError as exc:  # pragma: no cover - dependency missing at runtime
    raise SystemExit(
        "The ruamel.yaml package is required. Install it via 'pip install ruamel.yaml'."
    ) from exc

# Module-level aliases for the ruamel types used throughout the script.
YAML = _ruamel_yaml.YAML
CommentedMap = _ruamel_comments.CommentedMap
CommentedSeq = _ruamel_comments.CommentedSeq

# Text embedded in the comment written above the generated base block; prefix
# lines containing it are dropped on re-runs so the comment is not duplicated.
BASE_COMMENT_MARKER = "Base configuration extracted by YAMLCommon"


# Shared round-trip YAML instance used for every block load and every dump.
RT_YAML = YAML(typ="rt")
RT_YAML.preserve_quotes = True  # keep the original quoting of scalars
RT_YAML.allow_duplicate_keys = True  # do not fail on duplicate mapping keys
RT_YAML.width = 4096  # very wide line limit so long values are not re-wrapped
RT_YAML.indent(mapping=2, sequence=4, offset=2)


@dataclass
class Entry:
    """One top-level footprint block of the custom YAML file.

    Besides the parsed mapping, the raw text surrounding the block is kept so
    the file can be reassembled without losing comments or name lines.
    """

    # Raw lines that appeared right before the block (empty for the first block).
    leading_lines: list[str]
    # Blank lines, then comment lines, that directly followed the block.
    separator_lines: list[str]
    # Optional indented footprint-name line following the separator lines.
    name_line: str | None
    # All remaining lines up to the next block start, kept verbatim.
    after_name_lines: list[str]
    # The single top-level key of the block, e.g. "Vertical".
    root_key: str
    # The parameter mapping stored under root_key.
    inner_map: CommentedMap
    # The whole single-key mapping as loaded ({root_key: inner_map}).
    outer_map: CommentedMap
    # True when the block contained a "<<: *anchor" merge line.
    inherits_base: bool = False
    # True when the block text defines the "&anchor" itself.
    defines_anchor: bool = False


@dataclass
class ParsedFile:
    """Structural pieces of a parsed footprint YAML file."""

    # Lines before the first block (license headers, comments, ...).
    prefix_lines: list[str]
    # Footprint blocks in file order.
    entries: list[Entry]
    # Trailing lines collected after the final block.
    suffix_lines: list[str]
    # Mapping of a previously generated base block, if one was found.
    existing_base: CommentedMap | None = None


def is_block_start(line: str) -> bool:
    """Return True if *line* opens a new top-level footprint block.

    A block starts with an unindented ``key:`` line. Indented lines (space or
    tab) belong to the current block. Comment lines are never block starts,
    even when they contain a colon (e.g. ``# SPDX-License-Identifier: CC0-1.0``).
    Treating such a comment as a block can hand a comment-only chunk to the
    YAML loader, which makes parsing abort, or split an entry in the wrong
    place.
    """

    # Blank lines and any leading whitespace (not only spaces) never start a block.
    if not line.strip() or line[:1].isspace():
        return False
    stripped = line.strip()
    if stripped.startswith("#"):
        return False
    key, sep, _ = stripped.partition(":")
    return bool(sep) and bool(key.strip())


def is_name_line(line: str) -> bool:
    """Heuristically decide whether *line* is an indented footprint-name line.

    Such a line starts with at least two spaces. It is not blank, contains no
    colon, and is neither a comment nor a sequence item.
    """

    content = line.strip()
    looks_like_data = ":" in content or content.startswith(("#", "-"))
    if not content or looks_like_data:
        return False
    return line.startswith("  ")


def parse_file(text: str, anchor_name: str) -> ParsedFile:
    """Parse the non-standard footprint YAML *text* into structured entries.

    The file is treated as a sequence of blocks. A block starts at a line
    accepted by ``is_block_start`` and runs until the next blank line. Each
    block is loaded on its own and must be a single-key mapping whose value
    is a mapping. The text after a block is kept verbatim on the entry, so
    the file can be reassembled later:

    - blank lines, then comment lines;
    - an optional name line;
    - everything else up to the next block start.

    Args:
        text: Full contents of the YAML file.
        anchor_name: Name of the shared anchor. ``<<: *anchor_name`` lines are
            stripped before loading. A block whose raw text contains
            ``&anchor_name`` is flagged ``defines_anchor``.

    Returns:
        A ParsedFile. ``existing_base`` is always None at this stage.

    Raises:
        ValueError: If a block is empty after stripping merge lines, does not
            load as a single-key mapping, or its value is not a mapping.
            Errors raised by the YAML loader itself propagate unchanged.
    """

    lines = text.splitlines(keepends=True)
    idx = 0
    prefix_lines: List[str] = []
    suffix_lines: List[str] = []
    entries: List[Entry] = []
    first_block = True

    while True:
        # Collect everything up to the next block start.
        interim: List[str] = []
        while idx < len(lines) and not is_block_start(lines[idx]):
            interim.append(lines[idx])
            idx += 1

        # No further block. The collected text is either the whole prefix
        # (file without any block) or trailing suffix text.
        if idx >= len(lines):
            if first_block:
                prefix_lines = interim
            else:
                suffix_lines = interim
            break

        # Text before the first block is the file prefix. Text before later
        # blocks belongs to that block as its leading lines.
        if first_block:
            prefix_lines = interim
            leading_lines: List[str] = []
            first_block = False
        else:
            leading_lines = interim

        # The block body runs until the first blank line.
        block_lines: List[str] = []
        while idx < len(lines) and lines[idx].strip() != "":
            block_lines.append(lines[idx])
            idx += 1
        # Plain substring test on the raw text, done before merge lines are removed.
        defines_anchor = any(f"&{anchor_name}" in line for line in block_lines)
        # The block is loaded in isolation below. A "*anchor" alias there would
        # have no matching anchor, so the merge reference is stripped and
        # remembered as inherits_base.
        block_lines, inherits_base = strip_merge_marker(block_lines, anchor_name)
        block_text = "".join(block_lines)
        if not block_text.strip():
            raise ValueError("遇到空块;文件格式可能不受支持。")

        # Blank lines, then comment lines, directly after the block.
        separator_lines: List[str] = []
        while idx < len(lines) and lines[idx].strip() == "":
            separator_lines.append(lines[idx])
            idx += 1
        while idx < len(lines) and lines[idx].lstrip().startswith("#"):
            separator_lines.append(lines[idx])
            idx += 1

        # Optional indented footprint-name line (see is_name_line).
        name_line: Optional[str] = None
        if idx < len(lines) and is_name_line(lines[idx]):
            name_line = lines[idx]
            idx += 1

        # Everything else until the next block start is kept verbatim
        # and is not parsed as YAML.
        after_name_lines: List[str] = []
        while idx < len(lines) and not is_block_start(lines[idx]):
            after_name_lines.append(lines[idx])
            idx += 1

        # The block must have the shape {root_key: {parameters...}}.
        data = RT_YAML.load(block_text)
        if not isinstance(data, CommentedMap) or len(data) != 1:
            raise ValueError("每个块必须是具有单个根键(例如 Vertical)的映射。")

        root_key = next(iter(data))
        inner_map = data[root_key]
        if not isinstance(inner_map, CommentedMap):
            raise ValueError("顶层条目必须映射到一个参数字典。")

        entries.append(
            Entry(
                leading_lines=leading_lines,
                separator_lines=separator_lines,
                name_line=name_line,
                after_name_lines=after_name_lines,
                root_key=root_key,
                inner_map=inner_map,
                outer_map=data,
                inherits_base=inherits_base,
                defines_anchor=defines_anchor,
            )
        )

    return ParsedFile(prefix_lines=prefix_lines, entries=entries, suffix_lines=suffix_lines)


def strip_merge_marker(block_lines: List[str], anchor_name: str) -> tuple[List[str], bool]:
    """Drop ``<<: *anchor_name`` lines from a block.

    Only lines consisting of exactly that merge reference (surrounding
    whitespace ignored) are removed. The input list is not modified.

    Returns:
        The remaining lines and whether at least one merge line was removed.
    """

    merge_ref = f"<<: *{anchor_name}"
    kept = [line for line in block_lines if line.strip() != merge_ref]
    return kept, len(kept) != len(block_lines)


def strip_generated_base(parsed: ParsedFile, anchor_name: str) -> ParsedFile:
    """Remove a previously generated base block and its marker comment.

    Every entry flagged ``defines_anchor`` is dropped. The inner mapping of
    the last such entry becomes ``existing_base``; if there is none, the
    incoming ``existing_base`` is kept. Prefix lines containing
    BASE_COMMENT_MARKER are removed. *anchor_name* is currently unused.
    """

    base_map: Optional[CommentedMap] = parsed.existing_base
    kept_entries: List[Entry] = []
    for entry in parsed.entries:
        if entry.defines_anchor:
            base_map = entry.inner_map
        else:
            kept_entries.append(entry)

    return ParsedFile(
        prefix_lines=[ln for ln in parsed.prefix_lines if BASE_COMMENT_MARKER not in ln],
        entries=kept_entries,
        suffix_lines=parsed.suffix_lines,
        existing_base=base_map,
    )


def ensure_base_inheritance(parsed: ParsedFile) -> None:
    """Flag entries that do not fully match the existing base as inheriting it.

    Mutates ``parsed.entries`` in place and does nothing when there is no (or
    an empty) existing base. Entries that already carry a merge reference are
    left untouched. Any other entry for which ``entry_supports_base`` returns
    False gets ``inherits_base = True``; that is, a base key is missing or has
    a different value. The rewrite step can then emit it with a merge reference.
    """

    base = parsed.existing_base
    if base:
        candidates = (e for e in parsed.entries if not e.inherits_base)
        for entry in candidates:
            if not entry_supports_base(entry.inner_map, base):
                entry.inherits_base = True


def normalize(value):
    """Convert YAML container nodes into plain Python objects for comparison.

    Mappings become plain ``dict`` and sequences plain ``list``, recursively.
    Equality on the result therefore ignores ruamel metadata (comments,
    anchors, styles) and mapping key order. Any ``dict``/``list`` subclass is
    handled, and ruamel's CommentedMap/CommentedSeq are such subclasses.
    Plain or OrderedDict containers mixed into a node tree are normalized
    too. All other values are returned unchanged.
    """

    # Testing the builtin bases also unwraps OrderedDicts, whose ``==`` is
    # order-sensitive, so comparisons stay order-insensitive throughout.
    if isinstance(value, dict):
        return {key: normalize(item) for key, item in value.items()}
    if isinstance(value, list):
        return [normalize(item) for item in value]
    return value


def merged_with_base(base_map: CommentedMap, overrides: CommentedMap) -> CommentedMap:
    """Return a deep copy of *base_map* updated with *overrides*.

    Nested mappings present on both sides are merged recursively. Any other
    override value replaces the base value, as a deep copy. Neither input is
    mutated.
    """

    merged = copy.deepcopy(base_map)
    for key, new_value in overrides.items():
        current = merged[key] if key in merged else None
        both_maps = isinstance(current, CommentedMap) and isinstance(new_value, CommentedMap)
        if both_maps:
            merged[key] = merged_with_base(current, new_value)
        else:
            merged[key] = copy.deepcopy(new_value)
    return merged


def values_equal(lhs, rhs) -> bool:
    """Compare two YAML values structurally after ``normalize``-ing both sides."""

    left, right = normalize(lhs), normalize(rhs)
    return left == right


def _round_number(value, digits: Optional[int]):
    """如果请求,将数值 *value* 四舍五入到 *digits* 位小数。

    如果 *digits* 为 None,则原样返回该值。非数值
    也原样返回。
    """

    if digits is None:
        return value
    try:
        # 处理 int/float,同时保持其他类型不变
        if isinstance(value, (int, float)):
            return round(value, digits)
    except TypeError:
        pass
    return value


def _round_node(node, digits: Optional[int]):
    """Round every numeric scalar inside *node*, recursively and in place.

    Mappings and sequences are updated in place and returned. A scalar is
    returned, possibly rounded, via ``_round_number``. With *digits* None the
    node is returned untouched.
    """

    if digits is None:
        return node
    if isinstance(node, CommentedMap):
        snapshot = list(node.items())
        for key, child in snapshot:
            node[key] = _round_node(child, digits)
    elif isinstance(node, CommentedSeq):
        snapshot_seq = list(node)
        for position, child in enumerate(snapshot_seq):
            node[position] = _round_node(child, digits)
    else:
        return _round_number(node, digits)
    return node



def collect_common(
    entries: List[Entry],
    threshold: float,
    fallback_base: Optional[CommentedMap] = None,
) -> CommentedMap:
    """Return the top-level keys shared by enough entries with identical values.

    A key is included when both conditions hold:

    - it occurs in at least ``threshold`` (a fraction of ``len(entries)``) of
      the entries;
    - every occurrence compares equal via ``values_equal``.

    The first occurrence's value is deep-copied into the result. An entry
    that inherits the base is evaluated on its effective mapping, i.e.
    *fallback_base* merged with its own overrides. An empty mapping is
    returned when there are no entries.
    """

    result = CommentedMap()
    if not entries:
        return result
    entry_count = len(entries)

    seen: dict[str, List[object]] = {}
    for entry in entries:
        effective = entry.inner_map
        if entry.inherits_base and fallback_base is not None:
            effective = merged_with_base(fallback_base, entry.inner_map)
        for key, value in effective.items():
            seen.setdefault(key, []).append(value)

    for key, values in seen.items():
        if len(values) / entry_count < threshold:
            continue
        reference, *rest = values
        if all(values_equal(reference, candidate) for candidate in rest):
            result[key] = copy.deepcopy(reference)

    return result


def entry_supports_base(entry_map: CommentedMap, base_map: CommentedMap) -> bool:
    """Return True if *entry_map* agrees with every key of *base_map*.

    Each base key must be present in the entry. Nested mappings are checked
    recursively, so the entry's sub-map may contain extra keys. All other
    values, sequences included, must compare equal via ``values_equal``.
    """

    for key, expected in base_map.items():
        if key not in entry_map:
            return False
        actual = entry_map[key]
        if isinstance(expected, CommentedMap) and isinstance(actual, CommentedMap):
            compatible = entry_supports_base(actual, expected)
        else:
            compatible = values_equal(actual, expected)
        if not compatible:
            return False
    return True


def build_overrides(entry_map: CommentedMap, base_map: CommentedMap) -> CommentedMap:
    """Return the top-level keys of *entry_map* whose values differ from *base_map*.

    The result is written next to a ``<<: *anchor`` merge reference, and YAML
    merge keys are shallow: an explicit key replaces the merged value as a
    whole, and nested mappings are not merged. A key whose value differs from
    the base in any way is therefore emitted with its complete value
    (deep-copied). This includes a nested mapping that differs in a single
    sub-key or lacks one.

    It is never emitted as a partial nested diff. A partial diff would drop
    the unchanged sub-keys when the file is loaded, or wrongly inherit
    sub-keys the entry does not have.

    Keys absent from *base_map* are always included. Keys whose value equals
    the base (per ``values_equal``) are omitted.
    """

    overrides = CommentedMap()
    for key, entry_value in entry_map.items():
        if key in base_map and values_equal(entry_value, base_map[key]):
            continue  # inherited unchanged through the merge reference
        overrides[key] = copy.deepcopy(entry_value)
    return overrides


def dump_map(node: CommentedMap) -> str:
    """Serialize *node* with the shared round-trip dumper and return the text."""

    with io.StringIO() as sink:
        RT_YAML.dump(node, sink)
        return sink.getvalue()


def _supports_color(stream) -> bool:
    """如果 *stream* 可能支持 ANSI 颜色代码,则返回 True。"""

    return bool(getattr(stream, "isatty", lambda: False)()) and os.environ.get("TERM") not in {None, "", "dumb"}


def _colorize(text: str, color_code: str = "31") -> str:
    """Wrap *text* in an ANSI SGR color sequence when stderr supports color.

    *color_code* defaults to "31" (red). The text is returned unchanged when
    ``_supports_color(sys.stderr)`` is False.
    """

    if not _supports_color(sys.stderr):
        return text
    opener = f"\033[{color_code}m"
    return f"{opener}{text}\033[0m"


def inject_merge_line(rendered: str, anchor_name: str, indent: str = "  ") -> str:
    """Insert ``<<: *anchor_name`` directly below the first line of *rendered*.

    *rendered* is expected to start with the ``root_key:`` line of a dumped
    entry. A missing newline on that first line is added. For empty input,
    only the merge line itself is returned.
    """

    merge_ref = f"{indent}<<: *{anchor_name}\n"
    pieces = rendered.splitlines(keepends=True)
    if not pieces:
        return merge_ref
    head, *body = pieces
    if not head.endswith("\n"):
        head = head + "\n"
    return head + merge_ref + "".join(body)


def render_with_merge(root_key: str, overrides: CommentedMap, anchor_name: str) -> str:
    """Render ``root_key:`` followed by a merge reference and the overrides.

    With no overrides, only the root key and the ``<<: *anchor_name`` line are
    produced. Otherwise *overrides* is dumped under *root_key* and the merge
    line is inserted right after the root-key line.
    """

    if not overrides:
        return f"{root_key}:\n  <<: *{anchor_name}\n"
    wrapper = CommentedMap()
    wrapper[root_key] = overrides
    rendered = dump_map(wrapper)
    return inject_merge_line(rendered, anchor_name)


def rewrite(
    parsed: ParsedFile,
    base_map: Optional[CommentedMap],
    threshold: float,
    anchor_name: str,
    round_digits: Optional[int],
    target_path: Path,
) -> None:
    """Write *parsed* back to *target_path* (UTF-8) with a shared base block.

    Output layout, in order:

    1. The prefix lines.
    2. If *base_map* is non-empty, a ``base`` mapping anchored as
       ``&anchor_name`` and preceded by a comment that contains
       BASE_COMMENT_MARKER and the threshold.
    3. Every entry, with its surrounding raw text.
    4. The suffix lines.

    An entry is written as a merge reference plus the keys that differ from
    the base when it agrees with the whole base or is flagged
    ``inherits_base``. Otherwise it is dumped in full. When *round_digits* is
    set, numbers in the base and in every written entry are rounded first.
    """

    out: List[str] = list(parsed.prefix_lines)

    shared = None
    if base_map and len(base_map):
        shared = copy.deepcopy(base_map)
        # Round before anchoring so the overrides below are computed against
        # the values that actually end up in the file.
        if round_digits is not None:
            _round_node(shared, round_digits)
        shared.yaml_set_anchor(anchor_name, always_dump=True)
        # The generated block is always keyed "base", whatever the root key
        # of the originally anchored entry was.
        holder = CommentedMap()
        holder["base"] = shared
        holder.yaml_set_start_comment(
            f"Base configuration extracted by YAMLCommon (threshold={threshold:.2f})",
            indent=0,
        )
        out.append(dump_map(holder))
        out.append("\n")

    for entry in parsed.entries:
        out.extend(entry.leading_lines)
        derive = bool(shared) and (
            entry_supports_base(entry.inner_map, shared) or entry.inherits_base
        )
        if derive:
            # Diff against the (possibly rounded) base. An entry identical to
            # the base still gets a bare merge line so that it stays present.
            candidate = copy.deepcopy(entry.inner_map)
            if round_digits is not None:
                _round_node(candidate, round_digits)
            diff = build_overrides(candidate, shared)
            out.append(render_with_merge(entry.root_key, diff, anchor_name))
        else:
            # No usable base: dump the whole entry, rounded if requested.
            full = copy.deepcopy(entry.outer_map)
            if round_digits is not None:
                _round_node(full, round_digits)
            out.append(dump_map(full))
        out.extend(entry.separator_lines)
        name = entry.name_line
        if name:
            out.append(name if name.endswith("\n") else f"{name}\n")
        out.extend(entry.after_name_lines)

    out.extend(parsed.suffix_lines)

    target_path.write_text("".join(out), encoding="utf-8")


def print_base(base_map: Optional[CommentedMap], threshold: float) -> None:
    """Print the extracted base mapping, wrapped as ``base:``, to stdout.

    A short notice is printed instead when *base_map* is None or empty.
    """

    if base_map:
        wrapper = CommentedMap()
        wrapper["base"] = base_map
        print(f"提取的基元素(阈值={threshold:.2f}):")
        sys.stdout.write(dump_map(wrapper))
    else:
        print(f"没有基属性达到阈值({threshold:.2f})。")


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options.

    Args:
        argv: Arguments to parse. When None (the default), argparse falls
            back to ``sys.argv[1:]``, so script behavior is unchanged; passing
            a list allows programmatic and test use.

    Returns:
        A namespace with ``yaml_file``, ``threshold``, ``round``, ``output``,
        ``overwrite`` and ``anchor_name``.
    """

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("yaml_file", type=Path, help="footprint YAML 文件的路径")
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.9,
        # A key must also have the same value everywhere (see collect_common).
        help="一个键必须以相同取值出现的条目比例,范围 (0, 1](默认:0.9)",
    )
    parser.add_argument(
        "-r",
        "--round",
        type=int,
        default=None,
        help=(
            "在写入输出之前,将所有数值四舍五入到指定的小数位数。"
        ),
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        help=(
            "重写后的 YAML 文件的输出路径(默认:覆盖输入文件,"
            "需要 --overwrite)。"
        ),
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        # Required whenever the write target resolves to the input file,
        # i.e. also when --output is omitted.
        help=(
            "允许覆盖原始 YAML 文件;当未指定 --output 或 "
            "--output 指向原始文件时必须提供。"
        ),
    )
    parser.add_argument(
        "--anchor-name",
        default="base",
        help="用于共享基映射的锚点名称(默认:base)",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Command-line entry point: extract the common base and rewrite the file.

    The extracted base is always printed. Writing is skipped, with a warning
    on stderr, when the target resolves to the input file and --overwrite is
    missing. An invalid threshold or a file without entries aborts via
    SystemExit.
    """

    args = parse_args()
    if not (0 < args.threshold <= 1):
        raise SystemExit("阈值必须在 (0, 1] 范围内。")

    source = args.yaml_file.read_text(encoding="utf-8")
    parsed = strip_generated_base(parse_file(source, args.anchor_name), args.anchor_name)
    ensure_base_inheritance(parsed)
    if not parsed.entries:
        raise SystemExit("在提供的文件中未找到任何条目。")

    common_map = collect_common(parsed.entries, args.threshold, parsed.existing_base)

    base_for_write = None
    if parsed.existing_base:
        # Keep the existing base and layer any newly common keys on top of it.
        base_for_write = copy.deepcopy(parsed.existing_base)
        extra = build_overrides(common_map, base_for_write) if common_map else None
        if extra:
            base_for_write = merged_with_base(base_for_write, extra)
    elif common_map:
        base_for_write = copy.deepcopy(common_map)

    print_base(base_for_write, args.threshold)

    target = args.output or args.yaml_file
    if target.resolve() == args.yaml_file.resolve() and not args.overwrite:
        print(
            _colorize("未提供 --overwrite;输入文件保持不变,未写入任何输出。"),
            file=sys.stderr,
        )
        return

    rewrite(
        parsed,
        base_for_write or None,
        args.threshold,
        args.anchor_name,
        args.round,
        target,
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Check out similar posts by category: Python