import sys
import shlex


class SRTMerger:
    """Build a new SRT subtitle file from two input SRT files and a command script.

    The command file holds one command per line. Lines are tokenised with
    ``shlex`` (so quoting works and ``#`` starts a comment); the command
    name is case-insensitive. Indices are 1-based. Supported commands:

    * ``COPY <source> <n | start-end>`` -- append entries from file
      ``source`` (1 or 2) unchanged.
    * ``MAP <text_source> <text_start>-<text_end> <time_start>`` or
      ``MAP <text_source> <text_start> <time_start> <count>`` -- append
      entries whose text comes from ``text_source`` and whose timestamps
      come from the *other* file.
    * ``SYNC <text_source> <text_index> <time_index>`` -- append the text
      of entry ``text_index``; its timestamp is entry ``time_index`` of the
      other file shifted by the start-time gap between text entries
      ``text_index - 1`` and ``text_index``.
    * ``PUT <source> <index> <text>`` -- append ``text`` with the timestamp
      of entry ``index`` of file ``source``.

    Problems are collected as strings in ``self.errors`` rather than raised.
    """

    def __init__(self, srt1_filename, srt2_filename, commands_filename, output_filename):
        """Parse both SRT files and the command file.

        Nothing is written here; call ``process_commands`` and then
        ``write_output``. Check ``self.errors`` after construction.
        """
        self.errors = []
        self.srt1 = self.parse_srt(srt1_filename)
        self.srt2 = self.parse_srt(srt2_filename)
        self.commands = self.read_commands(commands_filename)
        self.output = []
        self.output_filename = output_filename

    def parse_srt(self, filename):
        """Read ``filename`` and return its entries.

        Returns a list of ``{'timestamp': str, 'text': str}`` dicts, one per
        block. The first line of each block (the cue number) is discarded,
        the second is kept verbatim as the timestamp, the rest is the text.
        Read failures and blocks with fewer than three lines are recorded in
        ``self.errors``.
        """
        entries = []
        try:
            # utf-8-sig transparently drops a leading BOM if present.
            with open(filename, 'r', encoding='utf-8-sig') as f:
                content = f.read().strip()
        except (OSError, UnicodeError) as e:
            self.errors.append(f"Error reading {filename}: {e}")
            return entries

        # Cues separated by more than one blank line leave stray newlines
        # (or entirely empty chunks) after splitting; drop them so the cue
        # number line is never mistaken for the timestamp line.
        stripped = (chunk.strip('\n') for chunk in content.split('\n\n'))
        blocks = [chunk for chunk in stripped if chunk]

        for block_number, block in enumerate(blocks, start=1):
            lines = block.split('\n')
            if len(lines) < 3:
                self.errors.append(f"Block {block_number} in {filename} has less than 3 lines.")
                continue
            entries.append({
                'timestamp': lines[1],
                'text': '\n'.join(lines[2:]).strip(),
            })
        return entries

    def read_commands(self, filename):
        """Read and validate the command file.

        Returns a list of command tuples as produced by the ``parse_*``
        methods. Unreadable files, lines that cannot be tokenised, unknown
        commands and invalid arguments are recorded in ``self.errors``
        (prefixed with the line number where applicable).
        """
        commands = []
        try:
            with open(filename, 'r', encoding='utf-8-sig') as f:
                lines = f.readlines()
        except (OSError, UnicodeError) as e:
            self.errors.append(f"Error reading {filename}: {e}")
            return commands

        parsers = {
            'COPY': self.parse_copy,
            'MAP': self.parse_map,
            'SYNC': self.parse_sync,
            'PUT': self.parse_put,
        }
        for line_number, line in enumerate(lines, start=1):
            try:
                parts = SRTMerger.split_line(line)
            except ValueError as e:
                # shlex raises ValueError for e.g. an unterminated quote.
                self.errors.append(f"Line {line_number}: {e}.")
                continue
            if not parts:
                continue  # blank line or comment only
            command = parts[0].upper()
            parser = parsers.get(command)
            if parser is None:
                self.errors.append(f"Line {line_number}: unknown command '{command}'.")
                continue
            parsed = parser(parts)
            if parsed is None:
                self.errors.append(f"Line {line_number}: invalid {command} command.")
            else:
                commands.append(parsed)
        return commands

    @staticmethod
    def split_line(line):
        """Tokenise one command line with shell-like quoting.

        ``#`` starts a comment. Raises ``ValueError`` (from ``shlex``) when
        the line cannot be tokenised, e.g. on an unbalanced quote.
        """
        return shlex.split(line.strip(), comments=True)

    def parse_copy(self, parts):
        """Parse ``COPY <source> <n | start-end>``.

        Returns ``('COPY', source, start, end)`` or ``None`` if invalid.
        """
        if len(parts) != 3:
            return None
        try:
            source = int(parts[1])
            if source not in (1, 2):
                return None
            if '-' in parts[2]:
                # Unpacking fails with ValueError unless there are exactly
                # two dash-separated integers.
                start, end = map(int, parts[2].split('-'))
            else:
                start = end = int(parts[2])
        except (ValueError, AttributeError, TypeError):
            return None
        if start > end:
            return None
        return ('COPY', source, start, end)

    def parse_map(self, parts):
        """Parse a MAP command in either supported syntax.

        Four tokens: ``MAP <text_source> <text_start>-<text_end> <time_start>``.
        Five tokens: ``MAP <text_source> <text_start> <time_start> <count>``.
        Returns ``('MAP', text_source, text_start, time_start, count)`` or
        ``None`` if invalid.
        """
        if len(parts) not in (4, 5):
            return None
        try:
            text_source = int(parts[1])
            if len(parts) == 4:
                bounds = parts[2].split('-')
                if len(bounds) != 2:
                    return None
                text_start = int(bounds[0])
                text_end = int(bounds[1])
                count = text_end - text_start + 1
                time_start = int(parts[3])
            else:
                text_start = int(parts[2])
                time_start = int(parts[3])
                count = int(parts[4])
        except (ValueError, AttributeError, TypeError):
            return None
        # count < 1 also covers text_end < text_start in the range syntax.
        if text_source not in (1, 2) or text_start < 1 or time_start < 1 or count < 1:
            return None
        return ('MAP', text_source, text_start, time_start, count)

    def parse_sync(self, parts):
        """Parse ``SYNC <text_source> <text_index> <time_index>``.

        ``text_index`` must be at least 2 because the entry before it is
        needed to compute the shift. Returns
        ``('SYNC', text_source, text_index, time_index)`` or ``None``.
        """
        if len(parts) != 4:
            return None
        try:
            text_source = int(parts[1])
            text_index = int(parts[2])
            time_index = int(parts[3])
        except (ValueError, AttributeError, TypeError):
            return None
        if text_source not in (1, 2) or text_index < 2 or time_index < 1:
            return None
        return ('SYNC', text_source, text_index, time_index)

    def parse_put(self, parts):
        """Parse ``PUT <source> <index> <text>``.

        Returns ``('PUT', source, index, text)`` or ``None`` if invalid.
        """
        if len(parts) != 4:
            return None
        try:
            source = int(parts[1])
            index = int(parts[2])
        except (ValueError, TypeError):
            return None
        if source not in (1, 2) or index < 1:
            return None
        return ('PUT', source, index, parts[3])

    def process_commands(self):
        """Execute every parsed command in order, appending to ``self.output``."""
        handlers = {
            'COPY': self.handle_copy,
            'MAP': self.handle_map,
            'SYNC': self.handle_sync,
            'PUT': self.handle_put,
        }
        for name, *args in self.commands:
            handlers[name](*args)

    def handle_copy(self, source, start, end):
        """Append entries ``start``..``end`` (1-based, inclusive) of file ``source``."""
        source_list = self.srt1 if source == 1 else self.srt2
        start_idx, end_idx = start - 1, end - 1
        if start_idx < 0 or end_idx >= len(source_list) or start_idx > end_idx:
            self.errors.append(f"Invalid COPY command: source {source}, range {start}-{end}.")
            return
        self.output.extend(source_list[start_idx:end_idx + 1])

    def handle_map(self, text_source, text_start, time_start, count):
        """Append ``count`` entries pairing text and timestamps across files.

        Text is taken from file ``text_source`` starting at ``text_start``;
        timestamps are taken from the other file starting at ``time_start``.
        """
        if text_source == 1:
            text_list, time_list = self.srt1, self.srt2
        else:
            text_list, time_list = self.srt2, self.srt1
        text_start_idx = text_start - 1
        time_start_idx = time_start - 1
        if (text_start_idx < 0 or text_start_idx + count > len(text_list)
                or time_start_idx < 0 or time_start_idx + count > len(time_list)):
            self.errors.append(
                f"Invalid MAP command: source {text_source}, text start {text_start}, "
                f"time start {time_start}, count {count}."
            )
            return
        text_slice = text_list[text_start_idx:text_start_idx + count]
        time_slice = time_list[time_start_idx:time_start_idx + count]
        for text_entry, time_entry in zip(text_slice, time_slice):
            self.output.append({
                'timestamp': time_entry['timestamp'],
                'text': text_entry['text'],
            })

    def handle_sync(self, text_source, text_index, time_index):
        """Append one entry whose timing is derived from the other file.

        The text is entry ``text_index`` of file ``text_source``. Its
        timestamp is entry ``time_index`` of the other file, shifted by the
        difference between the start times of text entries ``text_index``
        and ``text_index - 1``. Out-of-range indices and unparsable
        timestamps are recorded in ``self.errors``.
        """
        if text_source == 1:
            text_list, time_list = self.srt1, self.srt2
        else:
            text_list, time_list = self.srt2, self.srt1
        # time_index 1 is valid (the parser and the message below both say
        # ">=1"); only text_index needs a predecessor.
        text_ok = 2 <= text_index <= len(text_list)
        time_ok = 1 <= time_index <= len(time_list)
        if not (text_ok and time_ok):
            self.errors.append(
                f"Invalid SYNC command: text index {text_index} must be >=2 and <= {len(text_list)}, "
                f"time index {time_index} must be >=1 and <= {len(time_list)}."
            )
            return
        text_entry = text_list[text_index - 1]
        prev_text_entry = text_list[text_index - 2]
        time_entry = time_list[time_index - 1]
        try:
            delta = self.compute_delta(prev_text_entry['timestamp'], text_entry['timestamp'])
            new_ts = self.add_delta_to_timestamp(time_entry['timestamp'], delta)
        except (ValueError, IndexError) as e:
            # Timestamp lines are stored verbatim by parse_srt, so a
            # malformed one only surfaces here.
            self.errors.append(
                f"Invalid SYNC command: cannot parse timestamps for text index {text_index}, "
                f"time index {time_index}: {e}."
            )
            return
        self.output.append({
            'timestamp': new_ts,
            'text': text_entry['text'],
        })

    def handle_put(self, source, index, text):
        """Append ``text`` using the timestamp of entry ``index`` of file ``source``."""
        source_list = self.srt1 if source == 1 else self.srt2
        index_idx = index - 1
        if index_idx < 0 or index_idx >= len(source_list):
            self.errors.append(f"Invalid PUT command: source {source}, index {index} is out of bounds.")
            return
        self.output.append({
            'timestamp': source_list[index_idx]['timestamp'],
            'text': text,
        })

    def write_output(self):
        """Write ``self.output`` as SRT, renumbering entries from 1.

        Writes to standard output when the output filename is ``'-'``,
        otherwise to that file as UTF-8. Returns ``True`` on success; on an
        I/O failure prints a message to standard error and returns ``False``.
        """
        rendered = ''.join(
            f"{i}\n{entry['timestamp']}\n{entry['text'].strip()}\n\n"
            for i, entry in enumerate(self.output, start=1)
        )
        if self.output_filename == '-':
            sys.stdout.write(rendered)
            return True
        try:
            with open(self.output_filename, 'w', encoding='utf-8') as out:
                out.write(rendered)
        except OSError as e:
            print(f"Error writing to {self.output_filename}: {e}", file=sys.stderr)
            return False
        return True

    @staticmethod
    def parse_timestamp(ts_str):
        """Convert ``'HH:MM:SS,mmm --> HH:MM:SS,mmm'`` to ``(start_ms, end_ms)``.

        Raises ``ValueError`` or ``IndexError`` if the string does not have
        that shape.
        """
        def parse_part(part):
            hours, minutes, sec_ms = part.split(':')[:3]
            seconds, ms = sec_ms.split(',')[:2]
            return int(hours) * 3600000 + int(minutes) * 60000 + int(seconds) * 1000 + int(ms)

        start_str, end_str = ts_str.split(' --> ')
        return parse_part(start_str), parse_part(end_str)

    @staticmethod
    def compute_delta(prev_ts_str, curr_ts_str):
        """Return the start-time difference in ms (current minus previous)."""
        prev_start, _ = SRTMerger.parse_timestamp(prev_ts_str)
        curr_start, _ = SRTMerger.parse_timestamp(curr_ts_str)
        return curr_start - prev_start

    @staticmethod
    def format_time(ms):
        """Split a millisecond count into ``(hours, minutes, seconds, ms)``."""
        total_seconds, ms_part = divmod(ms, 1000)
        hours = total_seconds // 3600
        minutes = (total_seconds // 60) % 60
        seconds = total_seconds % 60
        return hours, minutes, seconds, ms_part

    @staticmethod
    def format_timestamp(start_ms, end_ms):
        """Render two millisecond counts as an SRT timestamp line."""
        def format_part(ms):
            h, m, s, ms_part = SRTMerger.format_time(ms)
            return f"{h:02}:{m:02}:{s:02},{ms_part:03}"

        return f"{format_part(start_ms)} --> {format_part(end_ms)}"

    @staticmethod
    def add_delta_to_timestamp(ts_str, delta):
        """Shift both ends of an SRT timestamp line by ``delta`` milliseconds."""
        start, end = SRTMerger.parse_timestamp(ts_str)
        return SRTMerger.format_timestamp(start + delta, end + delta)


def _report_errors(errors):
    """Print each error to standard error; return True if there were any."""
    for error in errors:
        print(error, file=sys.stderr)
    return bool(errors)


def main(argv=None):
    """Run the merger from the command line and return the exit status.

    ``argv`` is the argument list without the program name; it defaults to
    ``sys.argv[1:]``. Returns 0 on success and 1 on a usage error, any
    parse or command error, or a failed write.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 4:
        print("Usage: python merge_srt.py <srt1> <srt2> <commands> <output>", file=sys.stderr)
        return 1
    merger = SRTMerger(*args)
    if _report_errors(merger.errors):
        return 1
    merger.process_commands()
    if _report_errors(merger.errors):
        return 1
    return 0 if merger.write_output() else 1


if __name__ == '__main__':
    sys.exit(main())