use "../ssh_crypto"
use "../ssh_error"
class SshKexStateMachine
  """
  Helper for the SSH transport key exchange (RFC 4253): builds our
  SSH_MSG_KEXINIT payload, negotiates algorithms from the peer's KEXINIT,
  and derives session keys from the shared secret and exchange hash.
  """
  // Which side of the connection we are; decides whose algorithm
  // preferences take priority in `receive_kexinit`.
  let _role: SshRole

  new create(role: SshRole) =>
    """Create the helper for the given side of the connection."""
    _role = role

  fun ref generate_kexinit(prefs: SshAlgorithmPreferences val): Array[U8] val =>
    """
    Generate our SSH_MSG_KEXINIT payload from `prefs`, using a fresh random
    16-byte cookie obtained from `SshRandom.random_bytes`.
    """
    let cookie = SshRandom.random_bytes(16)
    SshMessages.kexinit(prefs, consume cookie)

  fun ref receive_kexinit(their_payload: Array[U8] val,
    our_prefs: SshAlgorithmPreferences val):
    (SshNegotiatedAlgorithms val | SshTransportError)
  =>
    """
    Parse the peer's KEXINIT payload and negotiate algorithms.

    Per RFC 4253 section 7.1 the client's preference order wins, so the
    client's list is always passed first to
    `SshAlgorithmNegotiation.negotiate`, whichever role we play.

    Returns `SshProtocolVersionMismatch` when `SshMessages.decode_kexinit`
    yields `None`, and `SshPacketCorrupt` when decoding raises an error.
    """
    try
      match SshMessages.decode_kexinit(their_payload)?
      | let their_prefs: SshAlgorithmPreferences val =>
        // Client preferences go first in negotiation per RFC 4253
        match _role
        | SshRoleClient =>
          SshAlgorithmNegotiation.negotiate(our_prefs, their_prefs)
        | SshRoleServer =>
          SshAlgorithmNegotiation.negotiate(their_prefs, our_prefs)
        end
      | None =>
        // NOTE(review): a None from decode_kexinit is reported as a version
        // mismatch; confirm that matches decode_kexinit's meaning of None.
        SshProtocolVersionMismatch
      end
    else
      SshPacketCorrupt
    end

  fun ref derive_keys(shared_secret: Array[U8] val,
    exchange_hash: Array[U8] val, session_id: Array[U8] val,
    negotiated: SshNegotiatedAlgorithms val):
    SshDerivedKeys val
  =>
    """
    Derive the six session values per RFC 4253 section 7.2.

    Each value is SHA-256(K || H || X || session_id), where K is the shared
    secret encoded as an mpint, H is the exchange hash and X is a single
    letter:
      'A' IV client->server,         'B' IV server->client,
      'C' encryption key c->s,       'D' encryption key s->c,
      'E' integrity (MAC) key c->s,  'F' integrity (MAC) key s->c.

    NOTE(review): `negotiated` is currently unused. Every value is exactly
    one SHA-256 output (32 bytes). RFC 4253 section 7.2 requires extending
    the output (K2 = HASH(K || H || K1), ...) when an algorithm needs more
    key material, for example a 64-byte HMAC-SHA2-512 key. Callers must
    truncate as needed. TODO: size keys from `negotiated`.
    """
    let iv_c2s = _derive_key(shared_secret, exchange_hash, 'A', session_id)
    let iv_s2c = _derive_key(shared_secret, exchange_hash, 'B', session_id)
    let enc_key_c2s = _derive_key(shared_secret, exchange_hash, 'C', session_id)
    let enc_key_s2c = _derive_key(shared_secret, exchange_hash, 'D', session_id)
    let mac_key_c2s = _derive_key(shared_secret, exchange_hash, 'E', session_id)
    let mac_key_s2c = _derive_key(shared_secret, exchange_hash, 'F', session_id)
    SshDerivedKeys(iv_c2s, iv_s2c, enc_key_c2s, enc_key_s2c,
      mac_key_c2s, mac_key_s2c)

  fun _derive_key(shared_secret: Array[U8] val, exchange_hash: Array[U8] val,
    letter: U8, session_id: Array[U8] val): Array[U8] val
  =>
    """Compute SHA-256(K_mpint || H || letter || session_id)."""
    let mpint = _encode_mpint(shared_secret)
    let input = recover val
      let buf = Array[U8](mpint.size() + exchange_hash.size() + 1
        + session_id.size())
      for b in mpint.values() do buf.push(b) end
      for b in exchange_hash.values() do buf.push(b) end
      buf.push(letter)
      for b in session_id.values() do buf.push(b) end
      buf
    end
    SshHash.sha256(input)

  fun _encode_mpint(value: Array[U8] val): Array[U8] val =>
    """
    Encode `value`, an unsigned big-endian magnitude, as an SSH mpint
    (RFC 4251 section 5): a 4-byte big-endian length followed by the
    minimal two's-complement data bytes.

    Leading zero bytes of `value` are stripped, because RFC 4251 forbids
    unnecessary leading zeros and the peer hashes the canonical encoding.
    A single 0x00 is then prepended if the first remaining byte has its
    high bit set, so the number stays positive. A zero value (empty or all
    zero bytes) encodes as length 0 with no data bytes.
    """
    // Advance to the first non-zero byte (ends at value.size() if all zero).
    var first: USize = 0
    try
      while (first < value.size()) and (value(first)? == 0) do
        first = first + 1
      end
    end
    let magnitude = value.trim(first)
    // One sign-padding byte when the top bit of the magnitude is set.
    let pad: USize =
      try
        if (magnitude(0)? and 0x80) != 0 then 1 else 0 end
      else
        0 // zero value: no data bytes, hence no padding
      end
    let len = (magnitude.size() + pad).u32()
    recover val
      let buf = Array[U8](4 + magnitude.size() + pad)
      buf.push((len >> 24).u8())
      buf.push((len >> 16).u8())
      buf.push((len >> 8).u8())
      buf.push(len.u8())
      if pad == 1 then buf.push(0) end
      for b in magnitude.values() do buf.push(b) end
      buf
    end
class val SshDerivedKeys
  """
  Immutable record of the six byte strings produced by SSH key derivation
  (RFC 4253 section 7.2). For each direction there is an initial IV, an
  encryption key and an integrity (MAC) key. Each field holds exactly the
  array passed to the constructor; nothing is copied or validated.
  """
  // Initial IVs (RFC letters 'A' = client->server, 'B' = server->client).
  let iv_c2s: Array[U8] val
  let iv_s2c: Array[U8] val
  // Encryption keys (letters 'C' = client->server, 'D' = server->client).
  let enc_key_c2s: Array[U8] val
  let enc_key_s2c: Array[U8] val
  // Integrity keys (letters 'E' = client->server, 'F' = server->client).
  let mac_key_c2s: Array[U8] val
  let mac_key_s2c: Array[U8] val

  new val create(
    iv_c2s': Array[U8] val,
    iv_s2c': Array[U8] val,
    enc_key_c2s': Array[U8] val,
    enc_key_s2c': Array[U8] val,
    mac_key_c2s': Array[U8] val,
    mac_key_s2c': Array[U8] val)
  =>
    """Store the six derived values as given."""
    // Client-to-server material.
    iv_c2s = iv_c2s'
    enc_key_c2s = enc_key_c2s'
    mac_key_c2s = mac_key_c2s'
    // Server-to-client material.
    iv_s2c = iv_s2c'
    enc_key_s2c = enc_key_s2c'
    mac_key_s2c = mac_key_s2c'