import re
import shlex
import sys

class SRTMerger:
    """Build a new SRT subtitle file from two source SRT files and a command file.

    Each source file is parsed into a list of entries (dicts with ``'timestamp'``
    and ``'text'`` keys). The commands then copy, re-time or replace entries
    into ``self.output``, which ``write_output`` renumbers and writes out.
    Problems are collected in ``self.errors`` instead of being raised.

    Commands (one per line, ``#`` starts a comment, shell-style quoting):

    * ``COPY <src> <n>`` or ``COPY <src> <start>-<end>``: copy entries verbatim.
    * ``MAP <src> <textstart>-<textend> <timestart>`` or
      ``MAP <src> <textstart> <timestart> <count>``: take the text from ``src``
      and the timestamps from the other file.
    * ``SYNC <src> <textindex> <timeindex>``: take the text of entry
      ``textindex`` of ``src``. Its timestamp is entry ``timeindex`` of the
      other file, shifted by the start-time gap between ``textindex`` and the
      entry before it in ``src``.
    * ``PUT <src> <index> <text>``: emit ``text`` with the timestamp of entry
      ``index`` of ``src``.

    All indices are 1-based; ``<src>`` is 1 or 2.
    """

    def __init__(self, srt1_filename, srt2_filename, commands_filename, output_filename):
        """Read both SRT files and the command file; problems go to ``self.errors``."""
        self.errors = []
        self.srt1 = self.parse_srt(srt1_filename)
        self.srt2 = self.parse_srt(srt2_filename)
        self.commands = self.read_commands(commands_filename)
        self.output = []
        self.output_filename = output_filename

    def parse_srt(self, filename):
        """Parse an SRT file into a list of ``{'timestamp', 'text'}`` dicts.

        The numeric index line of each block is ignored, because output is
        renumbered on write. Blocks may be separated by one or more blank (or
        whitespace-only) lines. Blocks with fewer than three lines are reported
        in ``self.errors`` and skipped. An unreadable or undecodable file gives
        ``[]`` plus an error. An empty file gives ``[]`` with no error.
        """
        entries = []
        try:
            # utf-8-sig transparently drops a leading BOM if present.
            with open(filename, 'r', encoding='utf-8-sig') as f:
                content = f.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            self.errors.append(f"Error reading {filename}: {e}")
            return entries
        if not content:
            return entries
        # Split on any run of blank/whitespace-only lines, so that extra
        # separators are not mistaken for (empty) subtitle blocks.
        blocks = re.split(r'\n\s*\n', content)
        for block_number, block in enumerate(blocks, start=1):
            lines = block.split('\n')
            if len(lines) < 3:
                self.errors.append(f"Block {block_number} in {filename} has less than 3 lines.")
                continue
            entries.append({
                'timestamp': lines[1].strip(),
                'text': '\n'.join(lines[2:]).strip(),
            })
        return entries

    def read_commands(self, filename):
        """Parse the command file into tuples understood by ``process_commands``.

        Blank and comment-only lines are skipped. The following are reported in
        ``self.errors`` with their 1-based line number and then skipped:
        unknown commands, malformed arguments and unbalanced quotes.
        """
        commands = []
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                lines = f.readlines()
        except (OSError, UnicodeDecodeError) as e:
            self.errors.append(f"Error reading {filename}: {e}")
            return commands
        parsers = {
            'COPY': self.parse_copy,
            'MAP': self.parse_map,
            'SYNC': self.parse_sync,
            'PUT': self.parse_put,
        }
        for line_number, line in enumerate(lines, start=1):
            try:
                parts = self.split_line(line)
            except ValueError as e:
                # shlex raises ValueError on e.g. an unterminated quote.
                self.errors.append(f"Line {line_number}: {e}.")
                continue
            if not parts:
                continue
            command = parts[0].upper()
            parser = parsers.get(command)
            if parser is None:
                self.errors.append(f"Line {line_number}: unknown command '{command}'.")
                continue
            parsed = parser(parts)
            if parsed is None:
                self.errors.append(f"Line {line_number}: invalid {command} command.")
            else:
                commands.append(parsed)
        return commands

    @staticmethod
    def split_line(line):
        """Split a command line shell-style, dropping ``#`` comments.

        Raises ValueError (from ``shlex``) on unbalanced quotes.
        """
        return shlex.split(line.strip(), comments=True)

    def parse_copy(self, parts):
        """Validate ``COPY <src> <n>|<start>-<end>``.

        Returns ``('COPY', src, start, end)``, or None if the command is invalid.
        """
        if len(parts) != 3:
            return None
        try:
            source = int(parts[1])
            if source not in (1, 2):
                return None
            if '-' in parts[2]:
                # More or fewer than two pieces raises ValueError on unpacking.
                start, end = map(int, parts[2].split('-'))
            else:
                start = end = int(parts[2])
            if start > end:
                return None
        except (ValueError, AttributeError, TypeError):
            return None
        return ('COPY', source, start, end)

    def parse_map(self, parts):
        """Validate ``MAP`` in its range form (4 parts) or count form (5 parts).

        Returns ``('MAP', src, text_start, time_start, count)``, or None if the
        command is invalid.
        """
        if len(parts) not in (4, 5):
            return None
        try:
            text_source = int(parts[1])
            if text_source not in (1, 2):
                return None
            if len(parts) == 4:
                # Range syntax: <textstart>-<textend> <timestart>
                start_end = parts[2].split('-')
                if len(start_end) != 2:
                    return None
                text_start = int(start_end[0])
                text_end = int(start_end[1])
                if text_end < text_start:
                    return None
                count = text_end - text_start + 1
                time_start = int(parts[3])
            else:
                # Count syntax: <textstart> <timestart> <count>
                text_start = int(parts[2])
                time_start = int(parts[3])
                count = int(parts[4])
            if text_start < 1 or time_start < 1 or count < 1:
                return None
        except (ValueError, AttributeError, TypeError):
            return None
        return ('MAP', text_source, text_start, time_start, count)

    def parse_sync(self, parts):
        """Validate ``SYNC <src> <textindex> <timeindex>``.

        ``textindex`` must be >= 2 because SYNC also reads the entry before it.
        Returns ``('SYNC', src, textindex, timeindex)``, or None if invalid.
        """
        if len(parts) != 4:
            return None
        try:
            text_source = int(parts[1])
            if text_source not in (1, 2):
                return None
            text_index = int(parts[2])
            time_index = int(parts[3])
            if text_index < 2 or time_index < 1:
                return None
        except (ValueError, AttributeError, TypeError):
            return None
        return ('SYNC', text_source, text_index, time_index)

    def parse_put(self, parts):
        """Validate ``PUT <src> <index> <text>``.

        Returns ``('PUT', src, index, text)``, or None if invalid.
        """
        if len(parts) != 4:
            return None
        try:
            source = int(parts[1])
            if source not in (1, 2):
                return None
            index = int(parts[2])
            if index < 1:
                return None
            text = parts[3]
        except (ValueError, TypeError):
            return None
        return ('PUT', source, index, text)

    def process_commands(self):
        """Run the parsed commands in file order, appending to ``self.output``."""
        handlers = {
            'COPY': self.handle_copy,
            'MAP': self.handle_map,
            'SYNC': self.handle_sync,
            'PUT': self.handle_put,
        }
        for cmd in self.commands:
            handlers[cmd[0]](*cmd[1:])

    def handle_copy(self, source, start, end):
        """Append entries ``start``..``end`` (1-based, inclusive) of ``source`` verbatim."""
        source_list = self.srt1 if source == 1 else self.srt2
        start_idx, end_idx = start - 1, end - 1
        if start_idx < 0 or end_idx >= len(source_list) or start_idx > end_idx:
            self.errors.append(f"Invalid COPY command: source {source}, range {start}-{end}.")
            return
        self.output.extend(source_list[start_idx:end_idx + 1])

    def handle_map(self, text_source, text_start, time_start, count):
        """Append ``count`` entries combining text and timestamps from two files.

        Each entry takes its text from ``text_source`` starting at ``text_start``
        and its timestamp from the other file starting at ``time_start``.
        """
        text_list = self.srt1 if text_source == 1 else self.srt2
        # Timestamps always come from the file that is NOT the text source.
        time_list = self.srt2 if text_source == 1 else self.srt1
        text_start_idx = text_start - 1
        time_start_idx = time_start - 1
        if (text_start_idx < 0 or text_start_idx + count > len(text_list) or
                time_start_idx < 0 or time_start_idx + count > len(time_list)):
            self.errors.append(f"Invalid MAP command: source {text_source}, text start {text_start}, time start {time_start}, count {count}.")
            return
        for i in range(count):
            self.output.append({
                'timestamp': time_list[time_start_idx + i]['timestamp'],
                'text': text_list[text_start_idx + i]['text'],
            })

    def handle_sync(self, text_source, text_index, time_index):
        """Append entry ``text_index`` of ``text_source`` re-timed against the other file.

        The new timestamp is entry ``time_index`` of the other file. Both its
        start and end are shifted by (start of ``text_index`` - start of
        ``text_index - 1``), measured in the text file. Malformed timestamps
        are reported in ``self.errors`` rather than raised.
        """
        text_list = self.srt1 if text_source == 1 else self.srt2
        time_list = self.srt2 if text_source == 1 else self.srt1
        text_idx = text_index - 1
        time_idx = time_index - 1
        prev_text_idx = text_idx - 1
        # prev_text_idx >= 0 already implies text_idx >= 1 (text_index >= 2).
        if (prev_text_idx < 0 or text_idx >= len(text_list) or
                time_idx < 0 or time_idx >= len(time_list)):
            self.errors.append(f"Invalid SYNC command: text index {text_index} must be >=2 and <= {len(text_list)}, time index {time_index} must be >=1 and <= {len(time_list)}.")
            return
        text_entry = text_list[text_idx]
        try:
            delta = self.compute_delta(text_list[prev_text_idx]['timestamp'], text_entry['timestamp'])
            new_ts = self.add_delta_to_timestamp(time_list[time_idx]['timestamp'], delta)
        except (ValueError, IndexError) as e:
            self.errors.append(f"Invalid SYNC command: malformed timestamp ({e}).")
            return
        self.output.append({
            'timestamp': new_ts,
            'text': text_entry['text'],
        })

    def handle_put(self, source, index, text):
        """Append ``text`` with the timestamp of entry ``index`` of ``source``."""
        source_list = self.srt1 if source == 1 else self.srt2
        index_idx = index - 1
        if index_idx < 0 or index_idx >= len(source_list):
            self.errors.append(f"Invalid PUT command: source {source}, index {index} is out of bounds.")
            return
        self.output.append({
            'timestamp': source_list[index_idx]['timestamp'],
            'text': text,
        })

    def write_output(self):
        """Write ``self.output`` as a renumbered SRT file (``'-'`` means stdout).

        Returns True on success. On an I/O error while opening, writing or
        closing the file, prints the error to stderr and returns False.
        """
        if self.output_filename == '-':
            out = sys.stdout
        else:
            try:
                out = open(self.output_filename, 'w', encoding='utf-8')
            except OSError as e:
                print(f"Error writing to {self.output_filename}: {e}", file=sys.stderr)
                return False
        try:
            try:
                for i, entry in enumerate(self.output, start=1):
                    out.write(f"{i}\n{entry['timestamp']}\n{entry['text'].strip()}\n\n")
            finally:
                # close() can itself fail (e.g. flushing to a full disk),
                # so it sits inside the outer OSError handler.
                if out is not sys.stdout:
                    out.close()
        except OSError as e:
            print(f"Error writing to {self.output_filename}: {e}", file=sys.stderr)
            return False
        return True

    @staticmethod
    def parse_timestamp(ts_str):
        """Parse ``'HH:MM:SS,mmm --> HH:MM:SS,mmm'`` into ``(start_ms, end_ms)``.

        Raises ValueError or IndexError if the string does not have that shape.
        """
        start_str, end_str = ts_str.split(' --> ')

        def parse_part(part):
            parts = part.split(':')
            hours = int(parts[0])
            minutes = int(parts[1])
            sec_ms = parts[2].split(',')
            seconds = int(sec_ms[0])
            ms = int(sec_ms[1])
            return hours * 3600000 + minutes * 60000 + seconds * 1000 + ms

        return parse_part(start_str), parse_part(end_str)

    @staticmethod
    def compute_delta(prev_ts_str, curr_ts_str):
        """Return the start time of ``curr_ts_str`` minus that of ``prev_ts_str``, in ms."""
        prev_start, _ = SRTMerger.parse_timestamp(prev_ts_str)
        curr_start, _ = SRTMerger.parse_timestamp(curr_ts_str)
        return curr_start - prev_start

    @staticmethod
    def format_time(ms):
        """Split a millisecond count into ``(hours, minutes, seconds, millis)``.

        Negative values are clamped to 0. SRT times cannot be negative, and a
        negative SYNC shift could otherwise produce output like ``-1:59:59,500``.
        """
        ms = max(0, ms)
        total_seconds, ms_part = divmod(ms, 1000)
        total_minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return hours, minutes, seconds, ms_part

    @staticmethod
    def format_timestamp(start_ms, end_ms):
        """Format two millisecond values as an SRT ``start --> end`` timestamp line."""
        def format_part(h, m, s, ms):
            return f"{h:02}:{m:02}:{s:02},{ms:03}"
        start_h, start_m, start_s, start_ms_part = SRTMerger.format_time(start_ms)
        end_h, end_m, end_s, end_ms_part = SRTMerger.format_time(end_ms)
        return f"{format_part(start_h, start_m, start_s, start_ms_part)} --> {format_part(end_h, end_m, end_s, end_ms_part)}"

    @staticmethod
    def add_delta_to_timestamp(ts_str, delta):
        """Shift both ends of an SRT timestamp line by ``delta`` milliseconds."""
        start, end = SRTMerger.parse_timestamp(ts_str)
        return SRTMerger.format_timestamp(start + delta, end + delta)

def _report_errors(errors):
    """Print each message in ``errors`` to stderr; return True if there were any."""
    for error in errors:
        print(error, file=sys.stderr)
    return bool(errors)


def main(argv=None):
    """Command-line entry point: merge SRT files as directed and return an exit code.

    ``argv`` defaults to ``sys.argv``. Returns 0 on success. Returns 1 on a
    usage error, on any input or command error (each printed to stderr), or
    when ``write_output`` reports failure.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) != 5:
        print("Usage: python merge_srt.py <srt1> <srt2> <commands> <output>", file=sys.stderr)
        return 1
    srt1_filename, srt2_filename, commands_filename, output_filename = argv[1:5]
    merger = SRTMerger(srt1_filename, srt2_filename, commands_filename, output_filename)
    # Stop before running any command if the inputs could not be parsed.
    if _report_errors(merger.errors):
        return 1
    merger.process_commands()
    if _report_errors(merger.errors):
        return 1
    # write_output prints its own error message; only an explicit False
    # return is treated as a failure.
    if merger.write_output() is False:
        return 1
    return 0


if __name__ == '__main__':
    sys.exit(main())