Skip to content

Instantly share code, notes, and snippets.

@mildsunrise
Last active March 1, 2024 11:43
Show Gist options
  • Save mildsunrise/815ffb594f77dff6e8838f8d41119e12 to your computer and use it in GitHub Desktop.
Save mildsunrise/815ffb594f77dff6e8838f8d41119e12 to your computer and use it in GitHub Desktop.
convert a captured TCP stream of an RTMP connection into FLV
'''
Reads the raw bytes of a captured TCP stream carrying [one direction of] an
RTMP connection, and blindly dumps the messages of every message stream above 0 into an FLV file.
$ ./rtmp2flv.py < tcp-stream > video.flv
'''
from enum import IntEnum, unique
from dataclasses import dataclass, field
from typing import NamedTuple, Self, Optional
import io
import sys
def read_exact(st: io.FileIO, n: int) -> bytes:
    '''
    Read exactly `n` bytes from `st`.

    Raw (unbuffered) streams such as `io.FileIO` may return fewer bytes than
    requested even though more data will follow, so keep reading until `n`
    bytes have been collected or the stream stops producing data.

    Raises EOFError if the stream ends before `n` bytes were read.
    '''
    # a non-blocking raw stream may return None instead of bytes
    data = st.read(n) or b''
    if len(data) < n:
        buf = bytearray(data)
        while len(buf) < n and (chunk := st.read(n - len(buf))):
            buf += chunk
        data = bytes(buf)
    if len(data) != n:
        # explicit raise rather than assert: must not vanish under `python -O`
        raise EOFError(f'unexpected EOF: expected {n}, found {len(data)}')
    return data
# LAYER 1 (chunk header)
class Header(NamedTuple):
''' chunk header (only used by read_chunks) '''
# this is already a bit "cooked", i.e. loses encoding info. in particular:
# - if stream_id >= 64, whether it was encoded in 1 or 2 bytes
fmt: int
stream_id: int # 2 = control
timestamp: int # absolute or delta, depending on fmt
length: int
message_type: int
message_stream_id: int
extended_timestamp: Optional[int] # present when timestamp has maximum value
@classmethod
def parse(cls, st: io.FileIO, chunk_headers: dict[int, Self]) -> Optional[Self]:
# basic header
if not (x := st.read(1)):
return None
x, = x
fmt, stream_id = x >> 6, x & ~(~0 << 6)
if stream_id < 2:
stream_id = 64 + int.from_bytes(read_exact(st, stream_id + 1), 'little')
# message header
prev = chunk_headers.get(stream_id)
timestamp = prev.timestamp if fmt > 2 else \
int.from_bytes(read_exact(st, 3), 'big')
length = prev.length if fmt > 1 else \
int.from_bytes(read_exact(st, 3), 'big')
message_type = prev.message_type if fmt > 1 else st.read(1)[0]
message_stream_id = (prev.message_stream_id if prev else 0) if fmt > 0 else \
int.from_bytes(read_exact(st, 4), 'little')
extended_timestamp = None if timestamp != ~(~0 << 24) else \
int.from_bytes(read_exact(st, 4), 'big')
chunk_headers[stream_id] = self = cls(
fmt=fmt, stream_id=stream_id, timestamp=timestamp, length=length,
message_type=message_type, message_stream_id=message_stream_id,
extended_timestamp=extended_timestamp,
)
return self
def message_timestamp(self, last_timestamp: int) -> int:
ts = self.extended_timestamp
if ts == None:
ts = self.timestamp
if self.fmt > 0:
ts += last_timestamp
return ts
# LAYER 2 (chunks -> messages)
@unique
class MessageType(IntEnum):
    '''
    RTMP message type IDs, i.e. the message_type byte of a chunk header.

    Values are written in decimal, as in the RTMP specification. The AMF3
    variants of data / shared-object / command messages sit 3 below their
    AMF0 counterparts.
    '''
    # protocol control and user control messages
    SET_CHUNK_SIZE = 1
    ABORT = 2
    ACK = 3  # acknowledgement
    CONTROL = 4  # user control message
    SERVER_BANDWIDTH = 5  # window acknowledgement size
    CLIENT_BANDWIDTH = 6  # set peer bandwidth
    VIRTUAL_CONTROL = 7  # NOTE(review): not in the Adobe RTMP spec; presumably a vendor extension
    # media payloads
    AUDIO = 8
    VIDEO = 9
    # AMF3-encoded variants
    DATA_EXTENDED = 15
    CONTAINER_EXTENDED = 16
    COMMAND_EXTENDED = 17  # AMF3 command
    # AMF0-encoded
    DATA = 18  # data message; onMetaData is sent as one of these
    CONTAINER = 19  # shared object message
    COMMAND = 20  # AMF0 command (invoke)
    UDP = 21  # NOTE(review): not in the Adobe RTMP spec; confirm meaning
    AGGREGATE = 22
    PRESENT = 23  # NOTE(review): not in the Adobe RTMP spec; confirm meaning
class MessageHeader(NamedTuple):
    '''
    Identifying fields of an RTMP message that is being reassembled.
    ChunkReader checks that they stay the same across the message's chunks.
    '''
    timestamp: int  # message timestamp as computed by ChunkReader
    type: MessageType  # message type ID
    stream_id: int  # message stream ID (not the chunk stream ID)
    length: int  # total payload length in bytes
@dataclass
class Message:
    '''
    A fully reassembled RTMP message, as returned by ChunkReader.
    Unlike MessageHeader, instances are mutable.
    '''
    timestamp: int  # message timestamp
    type: MessageType  # message type ID
    stream_id: int  # message stream ID (not the chunk stream ID)
    data: bytes  # complete payload
@dataclass
class ChunkReader:
    '''
    Reassembles RTMP messages from the chunk stream read from `st`.

    State kept per chunk stream ID:
    - the last chunk header, used to expand compressed fmt 1-3 headers;
    - the timestamp of the last completed message, the base for deltas;
    - the message currently being reassembled.

    Iterating yields (chunk stream ID, Message) for every completed message.
    '''
    st: io.FileIO
    handshake: tuple[bytes, bytes]  # the two 1536-byte blocks following the version byte
    chunk_headers: dict[int, Header] = field(default_factory=dict)
    last_timestamp: dict[int, int] = field(default_factory=dict)
    # partially received messages: header plus the payload collected so far
    messages: dict[int, tuple[MessageHeader, bytearray]] = field(default_factory=dict)
    max_chunk_size: int = 128  # largest chunk payload; callers update it on SET_CHUNK_SIZE

    @classmethod
    def read_handshake(cls, st: io.FileIO) -> Self:
        '''
        Consume the handshake (version byte 3, then two 1536-byte blocks).

        Returns a reader positioned at the first chunk.
        '''
        h0, = read_exact(st, 1)
        assert h0 == 0x03, f'wrong version {h0}'
        h1 = read_exact(st, 1536)
        h2 = read_exact(st, 1536)
        return cls(st, (h1, h2))

    def read_chunk(self) -> Optional[tuple[int, Message]]:
        '''
        Read one chunk and append its payload to the message it belongs to.

        Returns (chunk stream ID, Message) when this chunk completes a
        message, and None otherwise. Raises StopIteration at EOF.
        '''
        if not (header := Header.parse(self.st, self.chunk_headers)):
            raise StopIteration()
        stream_id = header.stream_id
        message = MessageHeader(
            header.message_timestamp(self.last_timestamp.get(stream_id, 0)),
            MessageType(header.message_type),
            header.message_stream_id,
            header.length,
        )

        pending = self.messages.get(stream_id)
        if pending is None:
            pending = (message, bytearray())
            self.messages[stream_id] = pending
        elif header.fmt == 3:
            # A type 3 continuation chunk carries no timestamp of its own.
            # Recomputing it from the inherited field would treat an fmt 0
            # absolute timestamp as a delta and add last_timestamp to it.
            # Keep the timestamp the message started with instead.
            message = message._replace(timestamp=pending[0].timestamp)
        prev, data = pending
        assert message == prev, f'message fields changed: had {prev}, got {message}'

        to_read = min(self.max_chunk_size, message.length - len(data))
        data.extend(read_exact(self.st, to_read))
        if len(data) < message.length:
            return None
        del self.messages[stream_id]
        self.last_timestamp[stream_id] = message.timestamp
        # drop the trailing `length` field; Message carries the payload instead
        return stream_id, Message(*message[:-1], bytes(data))

    def __iter__(self):
        return self

    def __next__(self):
        # read_chunk returns None for chunks that don't complete a message
        while not (x := self.read_chunk()):
            pass
        return x
# LAYER 3 (TODO)
if __name__ == '__main__':
    rtmp = ChunkReader.read_handshake(sys.stdin.buffer)

    # output FLV header
    outf = sys.stdout.buffer
    outf.write(b'FLV')
    outf.write(bytes([1]))  # version
    outf.write(bytes([0x05]))  # flags = audio + video
    outf.write((9).to_bytes(4, 'big'))  # length of header
    outf.write((0).to_bytes(4, 'big'))  # first tag length is 0

    for stream_id, message in rtmp:
        # protocol control messages (on chunk stream 2) change how later
        # chunks are parsed, so feed them back into the reader
        if message.type == MessageType.ABORT:
            assert stream_id == 2
            assert len(message.data) == 4
            # the aborted chunk stream may have no partial message pending
            rtmp.messages.pop(int.from_bytes(message.data, 'big'), None)
        if message.type == MessageType.SET_CHUNK_SIZE:
            assert stream_id == 2
            assert len(message.data) == 4
            rtmp.max_chunk_size = int.from_bytes(message.data, 'big')
            assert not (rtmp.max_chunk_size >> 31)
        if message.stream_id > 0:
            print(f'stream_id={stream_id}, ts={message.timestamp}, type={repr(message.type)}, stream={message.stream_id}, data={len(message.data)}', file=sys.stderr)
            message.stream_id -= 1
            # Output an FLV tag. The trailing tag size is computed from the
            # lengths rather than via outf.tell(), which raises when stdout is
            # a pipe.
            # NOTE(review): every message type on these streams is dumped
            # as-is, although FLV only defines tag types 8, 9 and 18.
            ts = message.timestamp & 0xFFFFFFFF  # FLV timestamps are 32-bit
            tag_header = b''.join([
                bytes([message.type]),  # packet type
                len(message.data).to_bytes(3, 'big'),  # payload size
                (ts & 0xFFFFFF).to_bytes(3, 'big'),  # timestamp lower
                (ts >> 24).to_bytes(1, 'big'),  # timestamp upper
                message.stream_id.to_bytes(3, 'big'),  # stream ID
            ])
            outf.write(tag_header)
            outf.write(message.data)  # payload
            # size of previous packet
            outf.write((len(tag_header) + len(message.data)).to_bytes(4, 'big'))
            outf.flush()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment