1use std::borrow::Cow;
7use std::{fmt, io};
8
9use bitcoin::hashes::{sha256, Hash};
10use bitcoin::secp256k1::{self, schnorr, PublicKey};
14use secp256k1_musig::musig;
15
16
/// Error returned when decoding a value through [`ProtocolEncoding`] fails.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolDecodingError {
    /// Wraps an [`io::Error`]; any `io::Error` converts into this variant via `From`.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
    /// The input could be read but is not a valid encoding.
    #[error("invalid protocol encoding: {message}")]
    Invalid {
        /// Human-readable description, shown in the `Display` output.
        message: String,
        /// Optional lower-level error that caused this one.
        source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
    },
}
28
29impl ProtocolDecodingError {
30 pub fn invalid(message: impl fmt::Display) -> Self {
32 Self::Invalid {
33 message: message.to_string(),
34 source: None,
35 }
36 }
37
38 pub fn invalid_err<E>(source: E, message: impl fmt::Display) -> Self
40 where
41 E: std::error::Error + Send + Sync + 'static,
42 {
43 Self::Invalid {
44 message: message.to_string(),
45 source: Some(Box::new(source)),
46 }
47 }
48}
49
50impl From<bitcoin::consensus::encode::Error> for ProtocolDecodingError {
51 fn from(e: bitcoin::consensus::encode::Error) -> Self {
52 match e {
53 bitcoin::consensus::encode::Error::Io(e) => Self::Io(e.into()),
54 e => Self::invalid_err(e, "bitcoin protocol decoding error"),
55 }
56 }
57}
58
59impl From<bitcoin::io::Error> for ProtocolDecodingError {
60 fn from(e: bitcoin::io::Error) -> Self {
61 Self::Io(e.into())
62 }
63}
64
65pub trait ProtocolEncoding: Sized {
67 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error>;
70
71 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError>;
73
74 fn serialize(&self) -> Vec<u8> {
76 let mut buf = Vec::new();
77 self.encode(&mut buf).expect("buffers don't produce I/O errors");
78 buf
79 }
80
81 fn deserialize(mut byte_slice: &[u8]) -> Result<Self, ProtocolDecodingError> {
83 Self::decode(&mut byte_slice)
84 }
85
86 fn serialize_hex(&self) -> String {
88 use hex_conservative::Case::Lower;
89 let mut buf = String::new();
90 let mut writer = hex_conservative::display::HexWriter::new(&mut buf, Lower);
91 self.encode(&mut writer).expect("no I/O errors for buffers");
92 buf
93 }
94
95 fn deserialize_hex(hex_str: &str) -> Result<Self, ProtocolDecodingError> {
97 let mut iter = hex_conservative::HexToBytesIter::new(hex_str).map_err(|e| {
98 ProtocolDecodingError::Io(io::Error::new(io::ErrorKind::InvalidData, e))
99 })?;
100 Self::decode(&mut iter)
101 }
102}
103
/// Convenience methods on any [`io::Write`] for emitting little-endian integers,
/// raw byte slices and compact-size integers.
pub trait WriteExt: io::Write {
    /// Writes a single byte.
    fn emit_u8(&mut self, v: u8) -> Result<(), io::Error> {
        self.write_all(&[v])
    }

    /// Writes a `u16` as two little-endian bytes.
    fn emit_u16(&mut self, v: u16) -> Result<(), io::Error> {
        let bytes = v.to_le_bytes();
        self.write_all(&bytes)
    }

    /// Writes a `u32` as four little-endian bytes.
    fn emit_u32(&mut self, v: u32) -> Result<(), io::Error> {
        let bytes = v.to_le_bytes();
        self.write_all(&bytes)
    }

    /// Writes a `u64` as eight little-endian bytes.
    fn emit_u64(&mut self, v: u64) -> Result<(), io::Error> {
        let bytes = v.to_le_bytes();
        self.write_all(&bytes)
    }

    /// Writes all bytes of `slice` unchanged.
    fn emit_slice(&mut self, slice: &[u8]) -> Result<(), io::Error> {
        self.write_all(slice)
    }

    /// Writes `value` as a variable-length integer and returns the number of bytes written.
    ///
    /// Layout by magnitude:
    /// - up to `0xFC`: the value itself as one byte;
    /// - up to `0xFFFF`: marker `0xFD` followed by a little-endian `u16`;
    /// - up to `0xFFFF_FFFF`: marker `0xFE` followed by a little-endian `u32`;
    /// - otherwise: marker `0xFF` followed by a little-endian `u64`.
    fn emit_compact_size(&mut self, value: impl Into<u64>) -> Result<usize, io::Error> {
        let n: u64 = value.into();
        // Each narrowing cast below is guarded by the range check of its branch.
        if n < 0xFD {
            self.emit_u8(n as u8)?;
            Ok(1)
        } else if n <= 0xFFFF {
            self.emit_u8(0xFD)?;
            self.emit_u16(n as u16)?;
            Ok(3)
        } else if n <= 0xFFFF_FFFF {
            self.emit_u8(0xFE)?;
            self.emit_u32(n as u32)?;
            Ok(5)
        } else {
            self.emit_u8(0xFF)?;
            self.emit_u64(n)?;
            Ok(9)
        }
    }
}

/// Every writer, sized or not, gets the [`WriteExt`] methods.
impl<W: io::Write + ?Sized> WriteExt for W {}
159
/// Convenience methods on any [`io::Read`] for reading little-endian integers,
/// fixed-size byte arrays and compact-size integers.
pub trait ReadExt: io::Read {
    /// Reads a single byte.
    fn read_u8(&mut self) -> Result<u8, io::Error> {
        let [byte] = self.read_byte_array::<1>()?;
        Ok(byte)
    }

    /// Reads two bytes as a little-endian `u16`.
    fn read_u16(&mut self) -> Result<u16, io::Error> {
        self.read_byte_array::<2>().map(u16::from_le_bytes)
    }

    /// Reads four bytes as a little-endian `u32`.
    fn read_u32(&mut self) -> Result<u32, io::Error> {
        self.read_byte_array::<4>().map(u32::from_le_bytes)
    }

    /// Reads eight bytes as a little-endian `u64`.
    fn read_u64(&mut self) -> Result<u64, io::Error> {
        self.read_byte_array::<8>().map(u64::from_le_bytes)
    }

    /// Fills `slice` completely, failing if the reader runs out of bytes first.
    fn read_slice(&mut self, slice: &mut [u8]) -> Result<(), io::Error> {
        self.read_exact(slice)
    }

    /// Reads exactly `N` bytes into a new array.
    fn read_byte_array<const N: usize>(&mut self) -> Result<[u8; N], io::Error> {
        let mut bytes = [0u8; N];
        self.read_exact(&mut bytes)?;
        Ok(bytes)
    }

    /// Reads a variable-length integer as written by `emit_compact_size`.
    ///
    /// # Errors
    ///
    /// Besides read failures, returns an [`io::ErrorKind::InvalidData`] error when a wide
    /// form (`0xFD`, `0xFE`, `0xFF` marker) carries a value that fits a narrower form.
    fn read_compact_size(&mut self) -> Result<u64, io::Error> {
        // For each marker: the decoded payload and the smallest value that form may carry.
        let (value, minimum): (u64, u64) = match self.read_u8()? {
            0xFD => (u64::from(self.read_u16()?), 0xFD),
            0xFE => (u64::from(self.read_u32()?), 0x1_0000),
            0xFF => (self.read_u64()?, 0x1_0000_0000),
            direct => return Ok(u64::from(direct)),
        };

        if value < minimum {
            Err(io::Error::new(io::ErrorKind::InvalidData, "non-minimal varint"))
        } else {
            Ok(value)
        }
    }
}

/// Every reader, sized or not, gets the [`ReadExt`] methods.
impl<R: io::Read + ?Sized> ReadExt for R {}
235
236
237impl ProtocolEncoding for PublicKey {
238 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
239 w.emit_slice(&self.serialize())
240 }
241
242 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
243 let mut buf = [0; secp256k1::constants::PUBLIC_KEY_SIZE];
244 r.read_slice(&mut buf[..])?;
245 PublicKey::from_slice(&buf).map_err(|e| {
246 ProtocolDecodingError::invalid_err(e, "invalid public key")
247 })
248 }
249}
250
251impl ProtocolEncoding for Option<sha256::Hash> {
252 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
253 if let Some(h) = self {
254 w.emit_u8(1)?;
255 w.emit_slice(&h.as_byte_array()[..])
256 } else {
257 w.emit_u8(0)
258 }
259 }
260
261 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
262 let first = r.read_u8()?;
263 if first == 0 {
264 Ok(None)
265 } else if first == 1 {
266 let mut buf = [0u8; 32];
267 r.read_slice(&mut buf)?;
268 Ok(Some(sha256::Hash::from_byte_array(buf)))
269 } else {
270 Err(ProtocolDecodingError::invalid("invalid optional hash prefix byte"))
271 }
272 }
273}
274
275impl ProtocolEncoding for Option<PublicKey> {
276 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
277 if let Some(pk) = self {
278 w.emit_slice(&pk.serialize())
279 } else {
280 w.emit_u8(0)
281 }
282 }
283
284 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
285 let first = r.read_u8()?;
286 if first == 0 {
287 Ok(None)
288 } else {
289 let mut pk = [first; secp256k1::constants::PUBLIC_KEY_SIZE];
290 r.read_slice(&mut pk[1..])?;
291 Ok(Some(PublicKey::from_slice(&pk).map_err(|e| {
292 ProtocolDecodingError::invalid_err(e, "invalid public key")
293 })?))
294 }
295 }
296}
297
298impl ProtocolEncoding for schnorr::Signature {
299 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
300 w.emit_slice(&self.serialize())
301 }
302
303 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
304 let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
305 r.read_slice(&mut buf[..])?;
306 schnorr::Signature::from_slice(&buf).map_err(|e| {
307 ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
308 })
309 }
310}
311
312impl ProtocolEncoding for Option<schnorr::Signature> {
313 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
314 if let Some(sig) = self {
315 w.emit_slice(&sig.serialize())
316 } else {
317 w.emit_slice(&[0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE])
318 }
319 }
320
321 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
322 let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
323 r.read_slice(&mut buf[..])?;
324 if buf == [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE] {
325 Ok(None)
326 } else {
327 Ok(Some(schnorr::Signature::from_slice(&buf).map_err(|e| {
328 ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
329 })?))
330 }
331 }
332}
333
334impl ProtocolEncoding for sha256::Hash {
335 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
336 w.emit_slice(&self[..])
337 }
338
339 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
340 let mut buf = [0; sha256::Hash::LEN];
341 r.read_exact(&mut buf[..])?;
342 Ok(sha256::Hash::from_byte_array(buf))
343 }
344}
345
346impl ProtocolEncoding for musig::PublicNonce {
347 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
348 w.emit_slice(&self.serialize())
349 }
350 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
351 Ok(Self::from_byte_array(&r.read_byte_array()?).map_err(|e| {
352 ProtocolDecodingError::invalid_err(e, "invalid musig public nonce")
353 })?)
354 }
355}
356
357impl ProtocolEncoding for musig::PartialSignature {
358 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
359 w.emit_slice(&self.serialize())
360 }
361 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
362 Ok(Self::from_byte_array(&r.read_byte_array()?).map_err(|e| {
363 ProtocolDecodingError::invalid_err(e, "invalid musig public nonce")
364 })?)
365 }
366}
367
368macro_rules! impl_bitcoin_encode {
371 ($name:ty) => {
372 impl ProtocolEncoding for $name {
373 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
374 let mut wrapped = bitcoin::io::FromStd::new(w);
375 bitcoin::consensus::Encodable::consensus_encode(self, &mut wrapped)?;
376 Ok(())
377 }
378
379 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
380 let mut wrapped = bitcoin::io::FromStd::new(r);
381 let ret = bitcoin::consensus::Decodable::consensus_decode(&mut wrapped)?;
382 Ok(ret)
383 }
384 }
385 };
386}
387
388impl_bitcoin_encode!(bitcoin::BlockHash);
389impl_bitcoin_encode!(bitcoin::OutPoint);
390impl_bitcoin_encode!(bitcoin::TxOut);
391
392
393impl<'a, T: ProtocolEncoding + Clone> ProtocolEncoding for Cow<'a, T> {
394 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error> {
395 ProtocolEncoding::encode(self.as_ref(), writer)
396 }
397
398 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError> {
399 Ok(Cow::Owned(ProtocolEncoding::decode(reader)?))
400 }
401}
402
403
404pub mod serde {
405 use std::fmt;
424 use std::borrow::Cow;
425 use std::marker::PhantomData;
426
427 use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
428
429 use super::ProtocolEncoding;
430
431 struct SerWrapper<'a, T>(&'a T);
432
433 impl<'a, T: ProtocolEncoding> Serialize for SerWrapper<'a, T> {
434 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
435 if s.is_human_readable() {
436 s.serialize_str(&self.0.serialize_hex())
437 } else {
438 s.serialize_bytes(&self.0.serialize())
439 }
440 }
441 }
442
443 struct DeWrapper<T>(T);
444
445 impl<'de, T: ProtocolEncoding> Deserialize<'de> for DeWrapper<T> {
446 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
447 if d.is_human_readable() {
448 let s = <Cow<'de, str>>::deserialize(d)?;
449 Ok(DeWrapper(ProtocolEncoding::deserialize_hex(s.as_ref())
450 .map_err(serde::de::Error::custom)?))
451 } else {
452 let b = <Cow<'de, [u8]>>::deserialize(d)?;
453 Ok(DeWrapper(ProtocolEncoding::deserialize(b.as_ref())
454 .map_err(serde::de::Error::custom)?))
455 }
456 }
457 }
458
459 pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &T, s: S) -> Result<S::Ok, S::Error> {
460 SerWrapper(v).serialize(s)
461 }
462
463 pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<T, D::Error> {
464 Ok(DeWrapper::<T>::deserialize(d)?.0)
465 }
466
467 pub mod vec {
468 use super::*;
469
470 pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &[T], s: S) -> Result<S::Ok, S::Error> {
471 let mut seq = s.serialize_seq(Some(v.len()))?;
472 for item in v {
473 ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
474 }
475 ser::SerializeSeq::end(seq)
476 }
477
478 pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<Vec<T>, D::Error> {
479 struct Visitor<T>(PhantomData<T>);
480
481 impl<'de, T: ProtocolEncoding> de::Visitor<'de> for Visitor<T> {
482 type Value = Vec<T>;
483
484 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
485 f.write_str("a vector of objects implementing ProtocolEncoding")
486 }
487
488 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
489 let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
490 while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
491 ret.push(v.0);
492 }
493 Ok(ret)
494 }
495 }
496 d.deserialize_seq(Visitor(PhantomData))
497 }
498 }
499
500 pub mod cow {
501 use super::*;
502
503 use std::borrow::Cow;
504
505 pub fn serialize<'a, T, S>(v: &Cow<'a, T>, s: S) -> Result<S::Ok, S::Error>
506 where
507 T: ProtocolEncoding + Clone,
508 S: Serializer,
509 {
510 SerWrapper(v.as_ref()).serialize(s)
511 }
512
513 pub fn deserialize<'d, T, D>(d: D) -> Result<Cow<'static, T>, D::Error>
514 where
515 T: ProtocolEncoding + Clone,
516 D: Deserializer<'d>,
517 {
518 Ok(Cow::Owned(DeWrapper::<T>::deserialize(d)?.0))
519 }
520
521 pub mod vec {
522 use super::*;
523
524 use std::borrow::Cow;
525
526 pub fn serialize<'a, T, S>(v: &Cow<'a, [T]>, s: S) -> Result<S::Ok, S::Error>
527 where
528 T: ProtocolEncoding + Clone,
529 S: Serializer,
530 {
531 let mut seq = s.serialize_seq(Some(v.len()))?;
532 for item in v.as_ref().iter() {
533 ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
534 }
535 ser::SerializeSeq::end(seq)
536 }
537
538 pub fn deserialize<'d, T, D>(d: D) -> Result<Cow<'static, [T]>, D::Error>
539 where
540 T: ProtocolEncoding + Clone,
541 D: Deserializer<'d>,
542 {
543 struct Visitor<T>(PhantomData<T>);
544
545 impl<'de, T: ProtocolEncoding + Clone + 'static> de::Visitor<'de> for Visitor<T> {
546 type Value = Cow<'static, [T]>;
547
548 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
549 f.write_str("a vector of objects implementing ProtocolEncoding")
550 }
551
552 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
553 let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
554 while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
555 ret.push(v.0);
556 }
557 Ok(ret.into())
558 }
559 }
560 d.deserialize_seq(Visitor(PhantomData))
561 }
562 }
563 }
564}
565
566
/// Round-trip assertion helpers, available in tests and behind the `test-util` feature.
#[cfg(any(test, feature = "test-util"))]
pub mod test {
    use bitcoin::hex::DisplayHex;
    use ::serde::{Deserialize, Serialize};
    use serde_json;

    use super::*;

    /// Asserts that `object` survives an encode/decode cycle and that encoding the
    /// decoded value reproduces the same bytes.
    ///
    /// # Panics
    ///
    /// Panics if decoding fails or either comparison does not hold.
    pub fn encoding_roundtrip<T>(object: &T)
    where
        T: ProtocolEncoding + fmt::Debug + PartialEq,
    {
        let first_pass = object.serialize();
        let decoded = T::deserialize(&first_pass).unwrap();
        assert_eq!(*object, decoded);

        // Compare as hex strings so a mismatch prints readably.
        let second_pass = decoded.serialize();
        let first_hex = first_pass.as_hex().to_string();
        let second_hex = second_pass.as_hex().to_string();
        assert_eq!(first_hex, second_hex);
    }

    /// Asserts that `object` survives a JSON serialize/deserialize cycle.
    ///
    /// # Panics
    ///
    /// Panics if serialization or deserialization fails, or the result differs.
    pub fn json_roundtrip<T>(object: &T)
    where
        T: fmt::Debug + PartialEq + Serialize + for<'de> Deserialize<'de>,
    {
        let json = serde_json::to_string(object).unwrap();
        let decoded = serde_json::from_str::<T>(&json).unwrap();

        assert_eq!(*object, decoded);
    }
}