#!/usr/bin/env python3
import argparse
from pathlib import Path

# Four-byte ASCII tag searched for to locate each S3V0 wrapper in the input.
S3V0 = b"S3V0"
# Byte signature searched for to locate the start of each embedded ASF stream.
ASF_MAGIC = b"\x30\x26\xB2\x75\x8E\x66\xCF\x11"  # ASF header GUID prefix (first 8 bytes)

def find_all(data: bytes, needle: bytes) -> list[int]:
    """Return the offsets of every occurrence of *needle* in *data*.

    Offsets are in ascending order. Overlapping occurrences are included,
    because each search resumes one byte after the previous hit rather than
    after the end of the match.
    """
    offsets: list[int] = []
    pos = data.find(needle)
    while pos != -1:
        offsets.append(pos)
        # Advance by a single byte so overlapping matches are not skipped.
        pos = data.find(needle, pos + 1)
    return offsets

def _write_stream(out_dir: Path, stem: str, index: int, payload: bytes) -> Path:
    """Write one carved stream as ``<stem>_stream_<index>.asf`` and return its path."""
    out_path = out_dir / f"{stem}_stream_{index:02d}.asf"
    out_path.write_bytes(payload)
    return out_path


def carve_s3p(path: Path, out_dir: Path, search_window: int = 512) -> int:
    """Carve embedded ASF streams out of *path* into *out_dir*.

    The whole file is read into memory and split at every ``S3V0`` tag. For
    each segment the first ASF header signature is located and everything from
    that signature up to the next ``S3V0`` tag (or end of file) is written as
    one ``.asf`` file. If the file contains no ``S3V0`` tag at all, it is
    split directly at each ASF header signature instead.

    Args:
        path: Input container file.
        out_dir: Directory that receives the carved files. It is created
            (including parents) once at least one tag or signature is found.
        search_window: Number of bytes after an ``S3V0`` tag searched first
            for the ASF signature. If nothing is found there, the rest of the
            segment is searched. Both searches start at the tag, so this value
            does not change which offset is ultimately chosen.

    Returns:
        The number of stream files written.

    Raises:
        OSError: If the input cannot be read or an output cannot be written.
    """
    data = path.read_bytes()
    s3v_offsets = find_all(data, S3V0)

    if not s3v_offsets:
        # Fallback: no S3V0 wrappers, so carve at the raw ASF headers instead.
        asf_offsets = find_all(data, ASF_MAGIC)
        if not asf_offsets:
            print(f"[!] No S3V0 or ASF headers found in {path.name}")
            return 0

        out_dir.mkdir(parents=True, exist_ok=True)
        # Each stream runs from its header to the next header (or end of file).
        bounds = [*asf_offsets, len(data)]
        for idx, (start, end) in enumerate(zip(bounds, bounds[1:])):
            out_path = _write_stream(out_dir, path.stem, idx, data[start:end])
            print(f"[+] Wrote {out_path.name}  ({end-start} bytes)")
        return len(asf_offsets)

    out_dir.mkdir(parents=True, exist_ok=True)
    written = 0

    # Each segment runs from its S3V0 tag to the next tag (or end of file).
    seg_bounds = [*s3v_offsets, len(data)]
    for idx, (s3v_start, seg_end) in enumerate(zip(seg_bounds, seg_bounds[1:])):
        # Look for the ASF header shortly after S3V0 first ...
        window_end = min(s3v_start + search_window, seg_end)
        asf_start = data.find(ASF_MAGIC, s3v_start, window_end)

        if asf_start == -1:
            # ... then anywhere inside the segment.
            asf_start = data.find(ASF_MAGIC, s3v_start, seg_end)

        if asf_start == -1:
            print(f"[-] Segment {idx}: S3V0 at {hex(s3v_start)} but no ASF magic found. Skipping.")
            continue

        # Number outputs by streams actually written, so skipped segments
        # leave no gaps in the file names.
        out_path = _write_stream(out_dir, path.stem, written, data[asf_start:seg_end])
        print(f"[+] Segment {idx}: S3V0 {hex(s3v_start)} -> ASF {hex(asf_start)} .. {hex(seg_end)}  => {out_path.name}")
        written += 1

    return written

def main():
    """Parse command-line arguments and carve the streams from the input file.

    When no output directory is given, a directory named after the input's
    stem with a ``_carved`` suffix is used, placed next to the input file.
    I/O failures (missing or unreadable input, unwritable output) are reported
    as a normal argparse error instead of an unhandled traceback.
    """
    ap = argparse.ArgumentParser(description="Carve embedded WMA/ASF streams from SDVX .s3p (S3P0 with nested S3V0)")
    ap.add_argument("input", type=Path, help="Input .s3p file")
    ap.add_argument("-o", "--out", type=Path, default=None, help="Output directory (default: <input stem>_carved next to the input)")
    ap.add_argument("--window", type=int, default=512, help="Bytes after S3V0 to search for ASF header before widening")
    args = ap.parse_args()

    inp: Path = args.input
    out_dir = args.out if args.out else inp.with_name(f"{inp.stem}_carved")

    try:
        n = carve_s3p(inp, out_dir, search_window=args.window)
    except OSError as exc:
        # ap.error prints the usage line plus this message and exits with status 2.
        ap.error(f"cannot process {inp}: {exc}")
    print(f"[=] Extracted {n} stream(s) into: {out_dir}")

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()