Persist stage context snapshots for run workflows
This commit is contained in:
@@ -5,7 +5,8 @@ import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
DESIGN_AGENT_ROOT = Path(r'D:\\ad-hoc\\kei\\design_agent')
|
||||
|
||||
DESIGN_AGENT_ROOT = Path(r'D:\ad-hoc\kei\design_agent')
|
||||
if str(DESIGN_AGENT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(DESIGN_AGENT_ROOT))
|
||||
|
||||
@@ -30,6 +31,7 @@ from src.pipeline_context import (
|
||||
)
|
||||
from src.renderer import render_slide_from_html
|
||||
from src.slide_measurer import capture_slide_screenshot, measure_rendered_heights
|
||||
|
||||
if not hasattr(html_generator, 'SIDEBAR_PROMPT') and hasattr(html_generator, '_LEGACY_SIDEBAR_PROMPT'):
|
||||
html_generator.SIDEBAR_PROMPT = html_generator._LEGACY_SIDEBAR_PROMPT
|
||||
if not hasattr(html_generator, 'FOOTER_PROMPT') and hasattr(html_generator, '_LEGACY_FOOTER_PROMPT'):
|
||||
@@ -48,10 +50,12 @@ def _load_json(path: Path) -> dict:
|
||||
return json.loads(path.read_text(encoding='utf-8-sig'))
|
||||
|
||||
|
||||
def _build_context(content: str, base_path: str, stage1a: dict, stage1b: dict) -> PipelineContext:
|
||||
ctx = create_context(content, base_path)
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding='utf-8')
|
||||
|
||||
normalized = normalize_mdx_content(content)
|
||||
|
||||
def _stage_0(ctx: PipelineContext) -> PipelineContext:
|
||||
normalized = normalize_mdx_content(ctx.raw_content)
|
||||
ctx.normalized = NormalizedContent(
|
||||
clean_text=normalized['clean_text'],
|
||||
title=normalized['title'],
|
||||
@@ -60,7 +64,11 @@ def _build_context(content: str, base_path: str, stage1a: dict, stage1b: dict) -
|
||||
tables=normalized['tables'],
|
||||
sections=normalized['sections'],
|
||||
)
|
||||
ctx.save_snapshot('stage_0')
|
||||
return ctx
|
||||
|
||||
|
||||
def _stage_1a(ctx: PipelineContext, stage1a: dict) -> PipelineContext:
|
||||
analysis_raw = stage1a['analysis']
|
||||
ctx.analysis = Analysis(
|
||||
core_message=analysis_raw['core_message'],
|
||||
@@ -68,15 +76,21 @@ def _build_context(content: str, base_path: str, stage1a: dict, stage1b: dict) -
|
||||
total_pages=analysis_raw.get('total_pages', 1),
|
||||
)
|
||||
ctx.page_structure = PageStructure(roles=stage1a['page_structure'])
|
||||
ctx.topics = [Topic(**raw) for raw in stage1a['topics']]
|
||||
ctx.save_snapshot('stage_1a')
|
||||
return ctx
|
||||
|
||||
|
||||
def _stage_1b(ctx: PipelineContext, stage1b: dict) -> PipelineContext:
|
||||
refined_map = {item['topic_id']: item for item in stage1b['concepts']}
|
||||
topics = []
|
||||
for raw in stage1a['topics']:
|
||||
merged = dict(raw)
|
||||
if raw['id'] in refined_map:
|
||||
merged.update(refined_map[raw['id']])
|
||||
for raw in ctx.topics:
|
||||
merged = raw.model_dump()
|
||||
if raw.id in refined_map:
|
||||
merged.update(refined_map[raw.id])
|
||||
topics.append(Topic(**merged))
|
||||
ctx.topics = topics
|
||||
ctx.save_snapshot('stage_1b')
|
||||
return ctx
|
||||
|
||||
|
||||
@@ -138,6 +152,7 @@ def _stage_1_5a(ctx: PipelineContext) -> PipelineContext:
|
||||
})
|
||||
ctx.slide_images = slide_images
|
||||
ctx.analysis = ctx.analysis.model_copy(update={'image_sizes': image_sizes or {}})
|
||||
ctx.save_snapshot('stage_1_5a')
|
||||
return ctx
|
||||
|
||||
|
||||
@@ -157,6 +172,7 @@ def _stage_1_7(ctx: PipelineContext) -> PipelineContext:
|
||||
)
|
||||
for role, ref in refs_raw.items()
|
||||
}
|
||||
ctx.save_snapshot('stage_1_7')
|
||||
return ctx
|
||||
|
||||
|
||||
@@ -184,6 +200,7 @@ def _stage_1_5b(ctx: PipelineContext) -> PipelineContext:
|
||||
)
|
||||
})
|
||||
ctx.containers = updated
|
||||
ctx.save_snapshot('stage_1_5b')
|
||||
return ctx
|
||||
|
||||
|
||||
@@ -218,7 +235,7 @@ async def _stage_2(ctx: PipelineContext) -> PipelineContext:
|
||||
for role, ci in ctx.containers.items()
|
||||
},
|
||||
}
|
||||
generated, _verification = await generate_with_retry(
|
||||
generated, verification = await generate_with_retry(
|
||||
content=ctx.raw_content,
|
||||
analysis=analysis_dict,
|
||||
container_specs=container_specs_dict,
|
||||
@@ -226,6 +243,16 @@ async def _stage_2(ctx: PipelineContext) -> PipelineContext:
|
||||
images=ctx.slide_images,
|
||||
)
|
||||
ctx.generated_html = generated
|
||||
verification_path = ctx.get_run_dir() / 'stage_2_verification.json'
|
||||
_write_json(verification_path, {
|
||||
area: {
|
||||
'passed': result.passed,
|
||||
'score': result.score,
|
||||
'errors': result.errors,
|
||||
}
|
||||
for area, result in verification.items()
|
||||
})
|
||||
ctx.save_snapshot('stage_2')
|
||||
return ctx
|
||||
|
||||
|
||||
@@ -239,15 +266,17 @@ def _stage_3(ctx: PipelineContext) -> PipelineContext:
|
||||
ctx.rendered_html = render_slide_from_html(ctx.generated_html, analysis_dict, ctx.preset)
|
||||
if ctx.base_path:
|
||||
ctx.rendered_html = embed_images(ctx.rendered_html, ctx.base_path)
|
||||
ctx.save_snapshot('stage_3')
|
||||
return ctx
|
||||
|
||||
|
||||
def _stage_4_lite(ctx: PipelineContext) -> PipelineContext:
|
||||
def _stage_4(ctx: PipelineContext) -> PipelineContext:
|
||||
ctx.measurement = measure_rendered_heights(ctx.rendered_html)
|
||||
ctx.screenshot_b64 = capture_slide_screenshot(ctx.rendered_html) or ''
|
||||
ctx.quality_score = 100 if not any(
|
||||
zone.get('overflowed') for zone in ctx.measurement.get('zones', {}).values()
|
||||
) else 60
|
||||
ctx.save_snapshot('stage_4')
|
||||
return ctx
|
||||
|
||||
|
||||
@@ -264,16 +293,22 @@ async def main() -> None:
|
||||
stage1a = _load_json(Path(args.stage1a))
|
||||
stage1b = _load_json(Path(args.stage1b))
|
||||
|
||||
ctx = _build_context(content, args.base_path, stage1a, stage1b)
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ctx = create_context(content, args.base_path)
|
||||
ctx.run_dir = str(out_dir)
|
||||
|
||||
ctx = _stage_0(ctx)
|
||||
ctx = _stage_1a(ctx, stage1a)
|
||||
ctx = _stage_1b(ctx, stage1b)
|
||||
ctx = _stage_1_5a(ctx)
|
||||
ctx = _stage_1_7(ctx)
|
||||
ctx = _stage_1_5b(ctx)
|
||||
ctx = await _stage_2(ctx)
|
||||
ctx = _stage_3(ctx)
|
||||
ctx = _stage_4_lite(ctx)
|
||||
ctx = _stage_4(ctx)
|
||||
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
(out_dir / 'generated_html.json').write_text(
|
||||
json.dumps(ctx.generated_html, ensure_ascii=False, indent=2),
|
||||
encoding='utf-8',
|
||||
@@ -287,10 +322,11 @@ async def main() -> None:
|
||||
ctx.model_dump_json(indent=2, exclude={'screenshot_b64', 'rendered_html'}),
|
||||
encoding='utf-8',
|
||||
)
|
||||
(out_dir / 'final_context.json').write_text(
|
||||
ctx.model_dump_json(indent=2, exclude={'screenshot_b64', 'rendered_html'}),
|
||||
encoding='utf-8',
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user