Last active: March 1, 2024 11:43
-
-
Save mildsunrise/815ffb594f77dff6e8838f8d41119e12 to your computer and use it in GitHub Desktop.
Convert a captured TCP stream of an RTMP connection into FLV.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Reads the contents of a TCP stream carrying [one side of] an
RTMP connection, and blindly dumps streams above 0 into an FLV file.

    $ ./rtmp2flv.py < tcp-stream > video.flv
'''
from enum import IntEnum, unique | |
from dataclasses import dataclass, field | |
from typing import NamedTuple, Self, Optional | |
import io | |
import sys | |
def read_exact(st: io.FileIO, n: int) -> bytes:
    '''
    Read exactly `n` bytes from `st`.

    A single `read()` may legitimately return fewer bytes than requested
    (raw/unbuffered files, pipes, sockets) without being at end of stream,
    so keep reading until `n` bytes have been collected or the stream
    reports EOF with an empty (or None) read.

    Raises EOFError if the stream ends before `n` bytes are available.
    '''
    buf = bytearray()
    while len(buf) < n:
        piece = st.read(n - len(buf))
        if not piece:
            break
        buf.extend(piece)
    # explicit raise instead of assert: input validation must survive `python -O`
    if len(buf) != n:
        raise EOFError(f'unexpected EOF: expected {n}, found {len(buf)}')
    return bytes(buf)
# LAYER 1 (chunk header) | |
class Header(NamedTuple): | |
''' chunk header (only used by read_chunks) ''' | |
# this is already a bit "cooked", i.e. loses encoding info. in particular: | |
# - if stream_id >= 64, whether it was encoded in 1 or 2 bytes | |
fmt: int | |
stream_id: int # 2 = control | |
timestamp: int # absolute or delta, depending on fmt | |
length: int | |
message_type: int | |
message_stream_id: int | |
extended_timestamp: Optional[int] # present when timestamp has maximum value | |
@classmethod | |
def parse(cls, st: io.FileIO, chunk_headers: dict[int, Self]) -> Optional[Self]: | |
# basic header | |
if not (x := st.read(1)): | |
return None | |
x, = x | |
fmt, stream_id = x >> 6, x & ~(~0 << 6) | |
if stream_id < 2: | |
stream_id = 64 + int.from_bytes(read_exact(st, stream_id + 1), 'little') | |
# message header | |
prev = chunk_headers.get(stream_id) | |
timestamp = prev.timestamp if fmt > 2 else \ | |
int.from_bytes(read_exact(st, 3), 'big') | |
length = prev.length if fmt > 1 else \ | |
int.from_bytes(read_exact(st, 3), 'big') | |
message_type = prev.message_type if fmt > 1 else st.read(1)[0] | |
message_stream_id = (prev.message_stream_id if prev else 0) if fmt > 0 else \ | |
int.from_bytes(read_exact(st, 4), 'little') | |
extended_timestamp = None if timestamp != ~(~0 << 24) else \ | |
int.from_bytes(read_exact(st, 4), 'big') | |
chunk_headers[stream_id] = self = cls( | |
fmt=fmt, stream_id=stream_id, timestamp=timestamp, length=length, | |
message_type=message_type, message_stream_id=message_stream_id, | |
extended_timestamp=extended_timestamp, | |
) | |
return self | |
def message_timestamp(self, last_timestamp: int) -> int: | |
ts = self.extended_timestamp | |
if ts == None: | |
ts = self.timestamp | |
if self.fmt > 0: | |
ts += last_timestamp | |
return ts | |
# LAYER 2 (chunks -> messages) | |
# RTMP message type IDs (the message_type field of a chunk header).
# Built with the functional Enum API; unique() rejects duplicate IDs.
MessageType = unique(IntEnum('MessageType', [
    ('SET_CHUNK_SIZE', 0x01),
    ('ABORT', 0x02),
    ('ACK', 0x03),
    ('CONTROL', 0x04),
    ('SERVER_BANDWIDTH', 0x05),
    ('CLIENT_BANDWIDTH', 0x06),
    ('VIRTUAL_CONTROL', 0x07),
    ('AUDIO', 0x08),
    ('VIDEO', 0x09),
    ('DATA_EXTENDED', 0x0F),
    ('CONTAINER_EXTENDED', 0x10),
    ('COMMAND_EXTENDED', 0x11),  # an AMF3 command
    ('DATA', 0x12),  # Invoke (onMetaData info is sent as such)
    ('CONTAINER', 0x13),
    ('COMMAND', 0x14),  # an AMF0 command
    ('UDP', 0x15),
    ('AGGREGATE', 0x16),
    ('PRESENT', 0x17),
]))
# Header fields of a message being reassembled from chunks; the payload
# itself is accumulated separately until `length` bytes have arrived.
MessageHeader = NamedTuple('MessageHeader', [
    ('timestamp', int),  # absolute message timestamp
    ('type', MessageType),
    ('stream_id', int),  # message stream ID
    ('length', int),  # total payload length of the message
])
@dataclass(init=True, repr=True, eq=True)
class Message:
    '''A complete RTMP message: header fields plus the reassembled payload.'''
    timestamp: int  # absolute message timestamp
    type: MessageType  # message type ID
    stream_id: int  # message stream ID (not the chunk stream ID)
    data: bytes  # full message payload
@dataclass
class ChunkReader:
    '''
    Reassembles RTMP messages from the chunk stream in `st`.

    Iterating yields `(chunk_stream_id, Message)` for every completed message;
    iteration stops at a clean EOF before a chunk header.
    '''
    st: io.FileIO
    handshake: tuple[bytes, bytes]  # the two 1536-byte handshake blocks
    chunk_headers: dict[int, Header] = field(default_factory=dict)  # last header per chunk stream
    last_timestamp: dict[int, int] = field(default_factory=dict)  # last completed message timestamp per chunk stream
    messages: dict[int, tuple[MessageHeader, bytearray]] = field(default_factory=dict)  # partial messages per chunk stream
    max_chunk_size: int = 128  # current incoming chunk size

    @classmethod
    def read_handshake(cls, st: io.FileIO) -> Self:
        '''
        Consume the version byte (must be 3) and the two 1536-byte handshake
        blocks from `st`, returning a reader positioned at the first chunk.
        '''
        h0, = read_exact(st, 1)
        assert h0 == 0x03, f'wrong version {h0}'
        h1 = read_exact(st, 1536)
        h2 = read_exact(st, 1536)
        return cls(st, (h1, h2))

    def read_chunk(self) -> Optional[tuple[int, Message]]:
        '''
        Read one chunk and append its payload to the message in progress on
        its chunk stream.

        Returns `(chunk_stream_id, Message)` when this chunk completes a
        message, or None if more chunks are still needed. Raises StopIteration
        at a clean EOF before a chunk header.
        '''
        if not (header := Header.parse(self.st, self.chunk_headers)):
            raise StopIteration()
        stream_id = header.stream_id
        message_type = MessageType(header.message_type)

        if (pending := self.messages.get(stream_id)) is None:
            # first chunk of a new message: resolve its absolute timestamp now
            last_timestamp = self.last_timestamp.get(stream_id, 0)
            message = MessageHeader(
                header.message_timestamp(last_timestamp),
                message_type,
                header.message_stream_id,
                header.length,
            )
            data = bytearray()
            self.messages[stream_id] = message, data
        else:
            # Continuation chunk: the timestamp was fixed by the first chunk.
            # Recomputing it is wrong when that chunk was fmt 0 (absolute): a
            # fmt 3 continuation inherits the absolute value, but
            # message_timestamp() would add last_timestamp on top of it.
            message, data = pending
            got = MessageHeader(message.timestamp, message_type,
                                header.message_stream_id, header.length)
            assert got == message, f'message fields changed: had {message}, got {got}'

        to_read = min(self.max_chunk_size, message.length - len(data))
        data.extend(read_exact(self.st, to_read))
        if len(data) < message.length:
            return None
        del self.messages[stream_id]
        self.last_timestamp[stream_id] = message.timestamp
        return stream_id, Message(*message[:-1], bytes(data))

    def __iter__(self) -> Self:
        return self

    def __next__(self) -> tuple[int, Message]:
        # keep consuming chunks until one of them completes a message
        while (result := self.read_chunk()) is None:
            pass
        return result
# LAYER 3 (TODO) | |
if __name__ == '__main__':
    # RTMP message types with a 1:1 FLV tag equivalent (FLV tag types
    # 8 = audio, 9 = video, 18 = script data); FLV defines no other tag types.
    FLV_TAG_TYPES = {MessageType.AUDIO, MessageType.VIDEO, MessageType.DATA}

    rtmp = ChunkReader.read_handshake(sys.stdin.buffer)

    # output FLV header
    outf = sys.stdout.buffer
    outf.write(b'FLV')
    outf.write(bytes([1]))  # version
    outf.write(bytes([0x05]))  # flags = audio + video
    outf.write((9).to_bytes(4, 'big'))  # length of header
    outf.write((0).to_bytes(4, 'big'))  # first tag length is 0

    for stream_id, message in rtmp:
        # control messages that change how the chunk stream is parsed
        if message.type == MessageType.ABORT:
            assert stream_id == 2
            assert len(message.data) == 4
            # the aborted chunk stream may have no partial message pending
            rtmp.messages.pop(int.from_bytes(message.data, 'big'), None)
        elif message.type == MessageType.SET_CHUNK_SIZE:
            assert stream_id == 2
            assert len(message.data) == 4
            chunk_size = int.from_bytes(message.data, 'big')
            # top bit must be clear; a zero chunk size could never make progress
            assert 0 < chunk_size < (1 << 31), f'invalid chunk size {chunk_size}'
            rtmp.max_chunk_size = chunk_size

        if message.stream_id <= 0:
            continue
        print(f'stream_id={stream_id}, ts={message.timestamp}, type={repr(message.type)}, stream={message.stream_id}, data={len(message.data)}', file=sys.stderr)
        if message.type not in FLV_TAG_TYPES:
            # not representable as an FLV tag; writing it would corrupt the file
            continue

        # output FLV packet
        payload = message.data
        timestamp = message.timestamp
        # NOTE(review): the FLV spec says StreamID is always 0; this keeps the
        # original mapping of RTMP message stream N to FLV stream N-1.
        flv_stream_id = message.stream_id - 1
        tag = b''.join([
            bytes([message.type]),  # packet type
            len(payload).to_bytes(3, 'big'),  # payload size
            (timestamp & 0xFFFFFF).to_bytes(3, 'big'),  # timestamp lower 24 bits
            ((timestamp >> 24) & 0xFF).to_bytes(1, 'big'),  # timestamp upper 8 bits
            flv_stream_id.to_bytes(3, 'big'),  # stream ID
            payload,
        ])
        outf.write(tag)
        # size of previous packet; computed rather than via tell(), which
        # raises when stdout is a pipe
        outf.write(len(tag).to_bytes(4, 'big'))
        outf.flush()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.