refactor: encapsulate SRT processing into SRTMerger class with properties and methods
This commit is contained in:
parent
49d44a8cc0
commit
8f3c0931bf
1 changed files with 189 additions and 174 deletions
363
merge_srt.py
363
merge_srt.py
|
@ -1,190 +1,205 @@
|
|||
import sys
|
||||
|
||||
def parse_srt(filename):
|
||||
with open(filename, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read().strip()
|
||||
blocks = content.split('\n\n')
|
||||
entries = []
|
||||
for block in blocks:
|
||||
lines = block.split('\n')
|
||||
if len(lines) < 3:
|
||||
continue
|
||||
index = int(lines[0])
|
||||
timestamp = lines[1]
|
||||
text = '\n'.join(lines[2:]).strip()
|
||||
entries.append({
|
||||
'index': index,
|
||||
'timestamp': timestamp,
|
||||
'text': text,
|
||||
class SRTMerger:
|
||||
def __init__(self, srt1_filename, srt2_filename, commands_filename, output_filename):
|
||||
self.srt1 = self.parse_srt(srt1_filename)
|
||||
self.srt2 = self.parse_srt(srt2_filename)
|
||||
self.commands = self.read_commands(commands_filename)
|
||||
self.output = []
|
||||
self.output_filename = output_filename
|
||||
|
||||
def parse_srt(self, filename):
|
||||
with open(filename, 'r', encoding='utf-8-sig') as f:
|
||||
content = f.read().strip()
|
||||
blocks = content.split('\n\n')
|
||||
entries = []
|
||||
for block in blocks:
|
||||
lines = block.split('\n')
|
||||
if len(lines) < 3:
|
||||
continue
|
||||
timestamp = lines[1]
|
||||
text = '\n'.join(lines[2:]).strip()
|
||||
entries.append({
|
||||
'timestamp': timestamp,
|
||||
'text': text,
|
||||
})
|
||||
return entries
|
||||
|
||||
def read_commands(self, filename):
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
commands = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
continue
|
||||
command = parts[0].upper()
|
||||
if command == 'COPY':
|
||||
if len(parts) != 3:
|
||||
continue
|
||||
try:
|
||||
source = int(parts[1])
|
||||
if source not in (1, 2):
|
||||
continue
|
||||
start, end = map(int, parts[2].split('-'))
|
||||
if start > end:
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
commands.append(('COPY', source, start, end))
|
||||
elif command == 'MAP':
|
||||
if len(parts) != 6:
|
||||
continue
|
||||
try:
|
||||
text_source = int(parts[1])
|
||||
text_start = int(parts[2])
|
||||
time_source = int(parts[3])
|
||||
time_start = int(parts[4])
|
||||
count = int(parts[5])
|
||||
if text_source not in (1, 2) or time_source not in (1, 2):
|
||||
continue
|
||||
if text_start < 1 or time_start < 1 or count < 1:
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
commands.append(('MAP', text_source, text_start, time_source, time_start, count))
|
||||
elif command == 'SYNC':
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
try:
|
||||
text_source = int(parts[1])
|
||||
text_index = int(parts[2])
|
||||
time_index = int(parts[3])
|
||||
if text_source not in (1, 2):
|
||||
continue
|
||||
if text_index < 1 or time_index < 2:
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
commands.append(('SYNC', text_source, text_index, time_index))
|
||||
else:
|
||||
continue
|
||||
return commands
|
||||
|
||||
def process_commands(self):
|
||||
for cmd in self.commands:
|
||||
if cmd[0] == 'COPY':
|
||||
self.handle_copy(cmd)
|
||||
elif cmd[0] == 'MAP':
|
||||
self.handle_map(cmd)
|
||||
elif cmd[0] == 'SYNC':
|
||||
self.handle_sync(cmd)
|
||||
self.write_output()
|
||||
|
||||
def handle_copy(self, cmd):
|
||||
source, start, end = cmd[1:]
|
||||
source_list = self.srt1 if source == 1 else self.srt2
|
||||
start_idx, end_idx = start - 1, end - 1
|
||||
if start_idx < 0 or end_idx >= len(source_list) or start_idx > end_idx:
|
||||
print(f"Skipping invalid COPY command: source {source}, range {start}-{end}")
|
||||
return
|
||||
self.output.extend(source_list[start_idx:end_idx + 1])
|
||||
|
||||
def handle_map(self, cmd):
|
||||
text_source, text_start, time_source, time_start, count = cmd[1:]
|
||||
text_list = self.srt1 if text_source == 1 else self.srt2
|
||||
time_list = self.srt1 if time_source == 1 else self.srt2
|
||||
text_start_idx = text_start - 1
|
||||
time_start_idx = time_start - 1
|
||||
if (text_start_idx < 0 or text_start_idx + count > len(text_list) or
|
||||
time_start_idx < 0 or time_start_idx + count > len(time_list)):
|
||||
print(f"Skipping invalid MAP command: text source {text_source}, start {text_start}, count {count} or time source {time_source}, start {time_start}, count {count}")
|
||||
return
|
||||
for i in range(count):
|
||||
text_entry = text_list[text_start_idx + i]
|
||||
time_entry = time_list[time_start_idx + i]
|
||||
self.output.append({
|
||||
'timestamp': time_entry['timestamp'],
|
||||
'text': text_entry['text'],
|
||||
})
|
||||
|
||||
def handle_sync(self, cmd):
|
||||
text_source, text_index, time_index = cmd[1:]
|
||||
text_list = self.srt1 if text_source == 1 else self.srt2
|
||||
time_list = self.srt2 if text_source == 1 else self.srt1
|
||||
text_start_idx = text_index - 1
|
||||
time_start_idx = time_index - 1
|
||||
prev_time_idx = time_start_idx - 1
|
||||
if (text_start_idx < 0 or text_start_idx >= len(text_list) or
|
||||
time_start_idx < 1 or prev_time_idx < 0 or time_start_idx >= len(time_list)):
|
||||
print(f"Skipping invalid SYNC command: text index {text_index} must be >=1 and <= {len(text_list)}, time index {time_index} must be >=2 and <= {len(time_list)}")
|
||||
return
|
||||
text_entry = text_list[text_start_idx]
|
||||
time_entry = time_list[time_start_idx]
|
||||
prev_time_entry = time_list[prev_time_idx]
|
||||
delta = self.compute_delta(prev_time_entry['timestamp'], time_entry['timestamp'])
|
||||
new_ts = self.add_delta_to_timestamp(time_entry['timestamp'], delta)
|
||||
self.output.append({
|
||||
'timestamp': new_ts,
|
||||
'text': text_entry['text'],
|
||||
})
|
||||
return entries
|
||||
|
||||
def read_commands(filename):
|
||||
with open(filename, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
commands = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
continue
|
||||
command = parts[0].upper()
|
||||
if command == 'COPY':
|
||||
if len(parts) != 3:
|
||||
continue
|
||||
try:
|
||||
source = int(parts[1])
|
||||
if source not in (1, 2):
|
||||
continue
|
||||
range_part = parts[2]
|
||||
start_str, end_str = range_part.split('-')
|
||||
start = int(start_str)
|
||||
end = int(end_str)
|
||||
if start > end:
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
commands.append(('COPY', source, start, end))
|
||||
elif command == 'MAP':
|
||||
if len(parts) != 6:
|
||||
continue
|
||||
try:
|
||||
text_source = int(parts[1])
|
||||
text_start = int(parts[2])
|
||||
time_source = int(parts[3])
|
||||
time_start = int(parts[4])
|
||||
count = int(parts[5])
|
||||
if text_source not in (1, 2) or time_source not in (1, 2):
|
||||
continue
|
||||
if text_start < 1 or time_start < 1 or count < 1:
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
commands.append(('MAP', text_source, text_start, time_source, time_start, count))
|
||||
elif command == 'SYNC':
|
||||
if len(parts) != 4:
|
||||
continue
|
||||
try:
|
||||
text_source = int(parts[1])
|
||||
text_index = int(parts[2])
|
||||
time_index = int(parts[3])
|
||||
if text_source not in (1, 2):
|
||||
continue
|
||||
if text_index < 1 or time_index < 2:
|
||||
continue
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
commands.append(('SYNC', text_source, text_index, time_index))
|
||||
def write_output(self):
|
||||
if self.output_filename == '-':
|
||||
out = sys.stdout
|
||||
else:
|
||||
continue
|
||||
return commands
|
||||
out = open(self.output_filename, 'w', encoding='utf-8')
|
||||
|
||||
def parse_timestamp(ts_str):
|
||||
start_str, end_str = ts_str.split(' --> ')
|
||||
def parse_part(part):
|
||||
parts = part.split(':')
|
||||
hours = int(parts[0])
|
||||
minutes = int(parts[1])
|
||||
sec_ms = parts[2].split(',')
|
||||
seconds = int(sec_ms[0])
|
||||
ms = int(sec_ms[1])
|
||||
return hours * 3600000 + minutes * 60000 + seconds * 1000 + ms
|
||||
return parse_part(start_str), parse_part(end_str)
|
||||
try:
|
||||
for i, entry in enumerate(self.output, start=1):
|
||||
out.write(f"{i}\n{entry['timestamp']}\n{entry['text'].strip()}\n\n")
|
||||
finally:
|
||||
if out is not sys.stdout:
|
||||
out.close()
|
||||
|
||||
def compute_delta(prev_ts_str, curr_ts_str):
|
||||
prev_start, _ = parse_timestamp(prev_ts_str)
|
||||
curr_start, _ = parse_timestamp(curr_ts_str)
|
||||
return curr_start - prev_start
|
||||
@staticmethod
|
||||
def parse_timestamp(ts_str):
|
||||
start_str, end_str = ts_str.split(' --> ')
|
||||
def parse_part(part):
|
||||
parts = part.split(':')
|
||||
hours = int(parts[0])
|
||||
minutes = int(parts[1])
|
||||
sec_ms = parts[2].split(',')
|
||||
seconds = int(sec_ms[0])
|
||||
ms = int(sec_ms[1])
|
||||
return hours * 3600000 + minutes * 60000 + seconds * 1000 + ms
|
||||
return parse_part(start_str), parse_part(end_str)
|
||||
|
||||
def format_time(ms):
|
||||
total_seconds = ms // 1000
|
||||
ms_part = ms % 1000
|
||||
hours = total_seconds // 3600
|
||||
minutes = (total_seconds // 60) % 60
|
||||
seconds = total_seconds % 60
|
||||
return hours, minutes, seconds, ms_part
|
||||
@staticmethod
|
||||
def compute_delta(prev_ts_str, curr_ts_str):
|
||||
prev_start, _ = SRTMerger.parse_timestamp(prev_ts_str)
|
||||
curr_start, _ = SRTMerger.parse_timestamp(curr_ts_str)
|
||||
return curr_start - prev_start
|
||||
|
||||
def format_timestamp(start_ms, end_ms):
|
||||
def format_part(h, m, s, ms):
|
||||
return f"{h:02}:{m:02}:{s:02},{ms:03}"
|
||||
start_h, start_m, start_s, start_ms_part = format_time(start_ms)
|
||||
end_h, end_m, end_s, end_ms_part = format_time(end_ms)
|
||||
return f"{format_part(start_h, start_m, start_s, start_ms_part)} --> {format_part(end_h, end_m, end_s, end_ms_part)}"
|
||||
@staticmethod
|
||||
def format_time(ms):
|
||||
total_seconds = ms // 1000
|
||||
ms_part = ms % 1000
|
||||
hours = total_seconds // 3600
|
||||
minutes = (total_seconds // 60) % 60
|
||||
seconds = total_seconds % 60
|
||||
return hours, minutes, seconds, ms_part
|
||||
|
||||
def add_delta_to_timestamp(ts_str, delta):
|
||||
start, end = parse_timestamp(ts_str)
|
||||
return format_timestamp(start + delta, end + delta)
|
||||
@staticmethod
|
||||
def format_timestamp(start_ms, end_ms):
|
||||
def format_part(h, m, s, ms):
|
||||
return f"{h:02}:{m:02}:{s:02},{ms:03}"
|
||||
start_h, start_m, start_s, start_ms_part = SRTMerger.format_time(start_ms)
|
||||
end_h, end_m, end_s, end_ms_part = SRTMerger.format_time(end_ms)
|
||||
return f"{format_part(start_h, start_m, start_s, start_ms_part)} --> {format_part(end_h, end_m, end_s, end_ms_part)}"
|
||||
|
||||
@staticmethod
|
||||
def add_delta_to_timestamp(ts_str, delta):
|
||||
start, end = SRTMerger.parse_timestamp(ts_str)
|
||||
return SRTMerger.format_timestamp(start + delta, end + delta)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) != 5:
|
||||
print("Usage: python merge_srt.py <srt1> <srt2> <commands> <output>")
|
||||
sys.exit(1)
|
||||
srt1_filename, srt2_filename, commands_filename, output_filename = sys.argv[1:5]
|
||||
|
||||
srt1 = parse_srt(srt1_filename)
|
||||
srt2 = parse_srt(srt2_filename)
|
||||
commands = read_commands(commands_filename)
|
||||
|
||||
output = []
|
||||
for cmd in commands:
|
||||
if cmd[0] == 'COPY':
|
||||
source, start, end = cmd[1:]
|
||||
source_list = srt1 if source == 1 else srt2
|
||||
start_idx, end_idx = start - 1, end - 1
|
||||
if start_idx < 0 or end_idx >= len(source_list) or start_idx > end_idx:
|
||||
print(f"Skipping invalid COPY command: source {source}, range {start}-{end}")
|
||||
continue
|
||||
output.extend(source_list[start_idx:end_idx + 1])
|
||||
elif cmd[0] == 'MAP':
|
||||
text_source, text_start, time_source, time_start, count = cmd[1:]
|
||||
text_list = srt1 if text_source == 1 else srt2
|
||||
time_list = srt1 if time_source == 1 else srt2
|
||||
text_start_idx = text_start - 1
|
||||
time_start_idx = time_start - 1
|
||||
if (text_start_idx < 0 or text_start_idx + count > len(text_list) or
|
||||
time_start_idx < 0 or time_start_idx + count > len(time_list)):
|
||||
print(f"Skipping invalid MAP command: text source {text_source}, start {text_start}, count {count} or time source {time_source}, start {time_start}, count {count}")
|
||||
continue
|
||||
for i in range(count):
|
||||
text_entry = text_list[text_start_idx + i]
|
||||
time_entry = time_list[time_start_idx + i]
|
||||
output.append({
|
||||
'index': len(output) + 1,
|
||||
'timestamp': time_entry['timestamp'],
|
||||
'text': text_entry['text'],
|
||||
})
|
||||
elif cmd[0] == 'SYNC':
|
||||
text_source, text_index, time_index = cmd[1:]
|
||||
text_list = srt1 if text_source == 1 else srt2
|
||||
time_list = srt2 if text_source == 1 else srt1
|
||||
text_start_idx = text_index - 1
|
||||
time_start_idx = time_index - 1
|
||||
prev_time_idx = time_start_idx - 1
|
||||
if (text_start_idx < 0 or text_start_idx >= len(text_list) or
|
||||
time_start_idx < 1 or prev_time_idx < 0 or time_start_idx >= len(time_list)):
|
||||
print(f"Skipping invalid SYNC command: text index {text_index} must be >=1 and <= {len(text_list)}, time index {time_index} must be >=2 and <= {len(time_list)}")
|
||||
continue
|
||||
text_entry = text_list[text_start_idx]
|
||||
time_entry = time_list[time_start_idx]
|
||||
prev_time_entry = time_list[prev_time_idx]
|
||||
delta = compute_delta(prev_time_entry['timestamp'], time_entry['timestamp'])
|
||||
new_ts = add_delta_to_timestamp(time_entry['timestamp'], delta)
|
||||
output.append({
|
||||
'index': len(output) + 1,
|
||||
'timestamp': new_ts,
|
||||
'text': text_entry['text'],
|
||||
})
|
||||
|
||||
if output_filename == '-':
|
||||
out = sys.stdout
|
||||
else:
|
||||
out = open(output_filename, 'w', encoding='utf-8')
|
||||
|
||||
try:
|
||||
for i, entry in enumerate(output, start=1):
|
||||
out.write(f"{i}\n{entry['timestamp']}\n{entry['text'].strip()}\n\n")
|
||||
finally:
|
||||
if out is not sys.stdout:
|
||||
out.close()
|
||||
merger = SRTMerger(srt1_filename, srt2_filename, commands_filename, output_filename)
|
||||
merger.process_commands()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue