"""Merge two SubRip (.srt) subtitle files under the control of a command script.

Usage: python merge_srt.py <srt1> <srt2> <commands> <output>

Pass ``-`` as <output> to write the merged subtitles to standard output.
"""

import re
import sys


class SRTMerger:
    """Assemble a new SRT file from two source SRT files and a command list.

    The command file holds one command per line; blank lines and lines
    starting with ``#`` are ignored. Cue numbers are 1-based positions in
    the parsed source file. Supported commands (case-insensitive):

    ``COPY <source> <start>-<end>``
        Append cues ``start``..``end`` (inclusive) of file ``source`` unchanged.
    ``MAP <text_source> <text_start> <time_source> <time_start> <count>``
        Append ``count`` cues pairing text from one file with timestamps
        from another (or the same) file.
    ``SYNC <text_source> <text_index> <time_index>``
        Append one cue whose text comes from ``text_source`` and whose
        timestamp is cue ``time_index`` of the *other* file, shifted forward
        by the gap between that cue's start and the previous cue's start.

    ``<source>`` values are ``1`` or ``2``. Diagnostics are written to
    standard error so they never mix with subtitles written to stdout.
    """

    # A cue separator: a newline followed by one or more blank or
    # whitespace-only lines. Matching the whole run keeps stray extra blank
    # lines from shifting the index/timestamp/text fields of the next cue.
    _BLOCK_SEPARATOR = re.compile(r'\n(?:[ \t]*\n)+')

    def __init__(self, srt1_filename, srt2_filename, commands_filename, output_filename):
        """Load both subtitle files and the command script.

        Args:
            srt1_filename: Path of the SRT file addressed as source ``1``.
            srt2_filename: Path of the SRT file addressed as source ``2``.
            commands_filename: Path of the command script.
            output_filename: Destination path, or ``-`` for standard output.
        """
        self.srt1 = self.parse_srt(srt1_filename)
        self.srt2 = self.parse_srt(srt2_filename)
        self.commands = self.read_commands(commands_filename)
        self.output = []
        self.output_filename = output_filename

    def parse_srt(self, filename):
        """Read an SRT file into a list of ``{'timestamp', 'text'}`` dicts.

        The file is decoded as UTF-8 with an optional BOM. Blocks with fewer
        than three lines, or whose second line is not a ``-->`` timing line,
        are skipped. The cue number stored in the file is discarded; cues are
        addressed by their position in the returned list.
        """
        with open(filename, 'r', encoding='utf-8-sig') as f:
            content = f.read().strip()

        entries = []
        for block in self._BLOCK_SEPARATOR.split(content):
            lines = block.strip().split('\n')
            if len(lines) < 3 or '-->' not in lines[1]:
                continue
            entries.append({
                'timestamp': lines[1].strip(),
                'text': '\n'.join(lines[2:]).strip(),
            })
        return entries

    def read_commands(self, filename):
        """Parse the command script into a list of command tuples.

        Returns tuples shaped ``('COPY', source, start, end)``,
        ``('MAP', text_source, text_start, time_source, time_start, count)``
        or ``('SYNC', text_source, text_index, time_index)``. Malformed or
        unknown commands are reported on standard error and skipped.
        """
        commands = []
        # utf-8-sig: a BOM would otherwise glue itself onto the first command
        # word and make that command unrecognisable.
        with open(filename, 'r', encoding='utf-8-sig') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue
                command = self._parse_command(line.split())
                if command is None:
                    print(f"Ignoring malformed command on line {line_no}: {line}",
                          file=sys.stderr)
                    continue
                commands.append(command)
        return commands

    @staticmethod
    def _parse_command(parts):
        """Turn the whitespace-split words of one command into a tuple, or None."""
        name, args = parts[0].upper(), parts[1:]
        try:
            if name == 'COPY' and len(args) == 2:
                source = int(args[0])
                # Unpacking raises ValueError unless the range is exactly "a-b".
                start, end = map(int, args[1].split('-'))
                if source in (1, 2) and start <= end:
                    return ('COPY', source, start, end)
            elif name == 'MAP' and len(args) == 5:
                text_source, text_start, time_source, time_start, count = map(int, args)
                if (text_source in (1, 2) and time_source in (1, 2)
                        and min(text_start, time_start, count) >= 1):
                    return ('MAP', text_source, text_start, time_source, time_start, count)
            elif name == 'SYNC' and len(args) == 3:
                text_source, text_index, time_index = map(int, args)
                # time_index starts at 2 because SYNC needs a preceding cue.
                if text_source in (1, 2) and text_index >= 1 and time_index >= 2:
                    return ('SYNC', text_source, text_index, time_index)
        except ValueError:
            pass
        return None

    def _entries(self, source):
        """Return the parsed cue list for source number ``1`` or ``2``."""
        return self.srt1 if source == 1 else self.srt2

    def process_commands(self):
        """Run every parsed command in order, then write the merged output."""
        handlers = {
            'COPY': self.handle_copy,
            'MAP': self.handle_map,
            'SYNC': self.handle_sync,
        }
        for cmd in self.commands:
            handler = handlers.get(cmd[0])
            if handler is not None:
                handler(cmd)
        self.write_output()

    def handle_copy(self, cmd):
        """Append an inclusive, 1-based range of cues from one source unchanged."""
        source, start, end = cmd[1:]
        source_list = self._entries(source)
        start_idx, end_idx = start - 1, end - 1
        if start_idx < 0 or end_idx >= len(source_list) or start_idx > end_idx:
            print(f"Skipping invalid COPY command: source {source}, range {start}-{end}",
                  file=sys.stderr)
            return
        self.output.extend(source_list[start_idx:end_idx + 1])

    def handle_map(self, cmd):
        """Append ``count`` cues combining text from one list with timestamps from another."""
        text_source, text_start, time_source, time_start, count = cmd[1:]
        text_list = self._entries(text_source)
        time_list = self._entries(time_source)
        text_start_idx = text_start - 1
        time_start_idx = time_start - 1
        if (text_start_idx < 0 or text_start_idx + count > len(text_list) or
                time_start_idx < 0 or time_start_idx + count > len(time_list)):
            print(f"Skipping invalid MAP command: text source {text_source}, start {text_start}, count {count} or time source {time_source}, start {time_start}, count {count}",
                  file=sys.stderr)
            return
        text_slice = text_list[text_start_idx:text_start_idx + count]
        time_slice = time_list[time_start_idx:time_start_idx + count]
        self.output.extend(
            {'timestamp': time_entry['timestamp'], 'text': text_entry['text']}
            for text_entry, time_entry in zip(text_slice, time_slice)
        )

    def handle_sync(self, cmd):
        """Append one cue of text timed from the other file's cue, shifted by a delta.

        The delta is the difference between the start of cue ``time_index``
        and the start of the cue before it; it is added to both ends of cue
        ``time_index``'s timestamp.
        """
        text_source, text_index, time_index = cmd[1:]
        text_list = self._entries(text_source)
        # Timing always comes from the file that does NOT supply the text.
        time_list = self.srt2 if text_source == 1 else self.srt1
        text_idx = text_index - 1
        time_idx = time_index - 1
        if not (0 <= text_idx < len(text_list) and 1 <= time_idx < len(time_list)):
            print(f"Skipping invalid SYNC command: text index {text_index} must be >=1 and <= {len(text_list)}, time index {time_index} must be >=2 and <= {len(time_list)}",
                  file=sys.stderr)
            return
        current_ts = time_list[time_idx]['timestamp']
        previous_ts = time_list[time_idx - 1]['timestamp']
        delta = self.compute_delta(previous_ts, current_ts)
        self.output.append({
            'timestamp': self.add_delta_to_timestamp(current_ts, delta),
            'text': text_list[text_idx]['text'],
        })

    def write_output(self):
        """Write the collected cues, renumbered from 1, to the output target."""
        cues = [
            f"{i}\n{entry['timestamp']}\n{entry['text'].strip()}\n\n"
            for i, entry in enumerate(self.output, start=1)
        ]
        if self.output_filename == '-':
            sys.stdout.writelines(cues)
            return
        with open(self.output_filename, 'w', encoding='utf-8') as out:
            out.writelines(cues)

    @staticmethod
    def parse_timestamp(ts_str):
        """Convert ``'HH:MM:SS,mmm --> HH:MM:SS,mmm'`` to ``(start_ms, end_ms)``.

        Raises:
            ValueError: If the string does not have exactly that shape.
        """
        start_str, end_str = ts_str.split(' --> ')

        def parse_part(part):
            hours, minutes, sec_ms = part.split(':')
            seconds, ms = sec_ms.split(',')
            return int(hours) * 3600000 + int(minutes) * 60000 + int(seconds) * 1000 + int(ms)

        return parse_part(start_str), parse_part(end_str)

    @staticmethod
    def compute_delta(prev_ts_str, curr_ts_str):
        """Return the start of ``curr_ts_str`` minus the start of ``prev_ts_str``, in ms."""
        prev_start, _ = SRTMerger.parse_timestamp(prev_ts_str)
        curr_start, _ = SRTMerger.parse_timestamp(curr_ts_str)
        return curr_start - prev_start

    @staticmethod
    def format_time(ms):
        """Split milliseconds into ``(hours, minutes, seconds, milliseconds)``.

        Negative input is clamped to zero: SRT cannot express negative times,
        and floor division would otherwise produce fields such as hour ``-1``.
        """
        total_seconds, ms_part = divmod(max(ms, 0), 1000)
        total_minutes, seconds = divmod(total_seconds, 60)
        hours, minutes = divmod(total_minutes, 60)
        return hours, minutes, seconds, ms_part

    @staticmethod
    def format_timestamp(start_ms, end_ms):
        """Render two millisecond offsets as ``'HH:MM:SS,mmm --> HH:MM:SS,mmm'``."""
        def format_part(ms):
            h, m, s, ms_part = SRTMerger.format_time(ms)
            return f"{h:02}:{m:02}:{s:02},{ms_part:03}"

        return f"{format_part(start_ms)} --> {format_part(end_ms)}"

    @staticmethod
    def add_delta_to_timestamp(ts_str, delta):
        """Shift both ends of an SRT timestamp string by ``delta`` milliseconds."""
        start, end = SRTMerger.parse_timestamp(ts_str)
        return SRTMerger.format_timestamp(start + delta, end + delta)


def main(argv=None):
    """Command-line entry point; ``argv`` defaults to ``sys.argv[1:]``."""
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 4:
        print("Usage: python merge_srt.py <srt1> <srt2> <commands> <output>", file=sys.stderr)
        sys.exit(1)
    srt1_filename, srt2_filename, commands_filename, output_filename = args
    merger = SRTMerger(srt1_filename, srt2_filename, commands_filename, output_filename)
    merger.process_commands()


if __name__ == '__main__':
    main()