ark/
encode.rs

1//!
2//! Definitions of protocol encodings.
3//!
4
5
6use std::borrow::Cow;
7use std::{fmt, io};
8
9use bitcoin::hashes::{sha256, Hash};
10// We use bitcoin::io::{Read, Write} here but we shouldn't have to.
11// I created this issue in the hope that rust-bitcoin fixes this nuisance:
12//  https://github.com/rust-bitcoin/rust-bitcoin/issues/4530
13use bitcoin::secp256k1::{self, schnorr, PublicKey};
14use secp256k1_musig::musig;
15
16
17/// Error occuring during protocol decoding.
18#[derive(Debug, thiserror::Error)]
19pub enum ProtocolDecodingError {
20	#[error("I/O error: {0}")]
21	Io(#[from] io::Error),
22	#[error("invalid protocol encoding: {message}")]
23	Invalid {
24		message: String,
25		source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
26	},
27}
28
29impl ProtocolDecodingError {
30	/// Create a new [ProtocolDecodingError::Invalid] with the given message.
31	pub fn invalid(message: impl fmt::Display) -> Self {
32		Self::Invalid {
33			message: message.to_string(),
34			source: None,
35		}
36	}
37
38	/// Create a new [ProtocolDecodingError::Invalid] with the given message and source error.
39	pub fn invalid_err<E>(source: E, message: impl fmt::Display) -> Self
40	where
41		E: std::error::Error + Send + Sync + 'static,
42	{
43		Self::Invalid {
44			message: message.to_string(),
45			source: Some(Box::new(source)),
46		}
47	}
48}
49
50impl From<bitcoin::consensus::encode::Error> for ProtocolDecodingError {
51	fn from(e: bitcoin::consensus::encode::Error) -> Self {
52		match e {
53			bitcoin::consensus::encode::Error::Io(e) => Self::Io(e.into()),
54			e => Self::invalid_err(e, "bitcoin protocol decoding error"),
55		}
56	}
57}
58
59impl From<bitcoin::io::Error> for ProtocolDecodingError {
60	fn from(e: bitcoin::io::Error) -> Self {
61	    Self::Io(e.into())
62	}
63}
64
65/// Trait for encoding objects according to the bark protocol encoding.
66pub trait ProtocolEncoding: Sized {
67	/// Encode the object into the writer.
68	//TODO(stevenroose) return nb of bytes written like bitcoin::consensus::Encodable does?
69	fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error>;
70
71	/// Decode the object from the writer.
72	fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError>;
73
74	/// Serialize the object into a byte vector.
75	fn serialize(&self) -> Vec<u8> {
76		let mut buf = Vec::new();
77		self.encode(&mut buf).expect("buffers don't produce I/O errors");
78		buf
79	}
80
81	/// Deserialize object from the given byte slice.
82	fn deserialize(mut byte_slice: &[u8]) -> Result<Self, ProtocolDecodingError> {
83		Self::decode(&mut byte_slice)
84	}
85
86	/// Serialize the object to a lowercase hex string.
87	fn serialize_hex(&self) -> String {
88		use hex_conservative::Case::Lower;
89		let mut buf = String::new();
90		let mut writer = hex_conservative::display::HexWriter::new(&mut buf, Lower);
91		self.encode(&mut writer).expect("no I/O errors for buffers");
92		buf
93	}
94
95	/// Deserialize object from hex slice.
96	fn deserialize_hex(hex_str: &str) -> Result<Self, ProtocolDecodingError> {
97		let mut iter = hex_conservative::HexToBytesIter::new(hex_str).map_err(|e| {
98			ProtocolDecodingError::Io(io::Error::new(io::ErrorKind::InvalidData, e))
99		})?;
100		Self::decode(&mut iter)
101	}
102}
103
/// Extension methods for writing primitive values in our encoding format.
///
/// All multi-byte integers are written in little-endian byte order.
pub trait WriteExt: io::Write {
	/// Write a single byte.
	fn emit_u8(&mut self, v: u8) -> Result<(), io::Error> {
		self.write_all(&[v])
	}

	/// Write a 16-bit unsigned integer as 2 little-endian bytes.
	fn emit_u16(&mut self, v: u16) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Write a 32-bit unsigned integer as 4 little-endian bytes.
	fn emit_u32(&mut self, v: u32) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Write a 64-bit unsigned integer as 8 little-endian bytes.
	fn emit_u64(&mut self, v: u64) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Write every byte of `slice`, without any length prefix.
	fn emit_slice(&mut self, slice: &[u8]) -> Result<(), io::Error> {
		self.write_all(slice)
	}

	/// Write `value` in compact size aka "VarInt" encoding, using the shortest
	/// form that can hold it.
	///
	/// Returns the number of bytes written: 1, 3, 5 or 9.
	fn emit_compact_size(&mut self, value: impl Into<u64>) -> Result<usize, io::Error> {
		let value: u64 = value.into();

		// Values below the first marker byte are written as-is.
		if value < 0xFD {
			self.emit_u8(value as u8)?;
			return Ok(1);
		}
		if value <= u64::from(u16::MAX) {
			self.emit_u8(0xFD)?;
			self.emit_u16(value as u16)?;
			return Ok(3);
		}
		if value <= u64::from(u32::MAX) {
			self.emit_u8(0xFE)?;
			self.emit_u32(value as u32)?;
			return Ok(5);
		}
		self.emit_u8(0xFF)?;
		self.emit_u64(value)?;
		Ok(9)
	}
}

/// Every writer gets the [WriteExt] helpers.
impl<W: io::Write + ?Sized> WriteExt for W {}
159
/// Utility trait to read some primitive values from our encoding format.
///
/// All multi-byte integers are read in little-endian byte order. Every read
/// goes through [io::Read::read_exact], so running out of input yields an
/// [io::ErrorKind::UnexpectedEof] error.
pub trait ReadExt: io::Read {
	/// Read an 8-bit unsigned integer.
	fn read_u8(&mut self) -> Result<u8, io::Error> {
		self.read_byte_array::<1>().map(u8::from_le_bytes)
	}

	/// Read a 16-bit unsigned integer in little-endian.
	fn read_u16(&mut self) -> Result<u16, io::Error> {
		self.read_byte_array::<2>().map(u16::from_le_bytes)
	}

	/// Read a 32-bit unsigned integer in little-endian.
	fn read_u32(&mut self) -> Result<u32, io::Error> {
		self.read_byte_array::<4>().map(u32::from_le_bytes)
	}

	/// Read a 64-bit unsigned integer in little-endian.
	fn read_u64(&mut self) -> Result<u64, io::Error> {
		self.read_byte_array::<8>().map(u64::from_le_bytes)
	}

	/// Read from the reader to fill the entire slice.
	fn read_slice(&mut self, slice: &mut [u8]) -> Result<(), io::Error> {
		self.read_exact(slice)
	}

	/// Read exactly `N` bytes into a fixed-size array.
	fn read_byte_array<const N: usize>(&mut self) -> Result<[u8; N], io::Error> {
		let mut ret = [0u8; N];
		self.read_exact(&mut ret)?;
		Ok(ret)
	}

	/// Read a value in compact size aka "VarInt" encoding.
	///
	/// Only the shortest encoding of a value is accepted; a value encoded with
	/// a wider form than necessary is rejected with
	/// [io::ErrorKind::InvalidData] ("non-minimal varint").
	fn read_compact_size(&mut self) -> Result<u64, io::Error> {
		// For each wide form: the decoded value and the smallest value that
		// actually needs that form.
		let (value, min_value) = match self.read_u8()? {
			0xFF => (self.read_u64()?, 0x1_0000_0000),
			0xFE => (u64::from(self.read_u32()?), 0x1_0000),
			0xFD => (u64::from(self.read_u16()?), 0xFD),
			n => return Ok(u64::from(n)),
		};

		if value < min_value {
			Err(io::Error::new(io::ErrorKind::InvalidData, "non-minimal varint"))
		} else {
			Ok(value)
		}
	}
}

impl<R: io::Read + ?Sized> ReadExt for R {}
235
236
237impl ProtocolEncoding for PublicKey {
238	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
239		w.emit_slice(&self.serialize())
240	}
241
242	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
243		let mut buf = [0; secp256k1::constants::PUBLIC_KEY_SIZE];
244		r.read_slice(&mut buf[..])?;
245		PublicKey::from_slice(&buf).map_err(|e| {
246			ProtocolDecodingError::invalid_err(e, "invalid public key")
247		})
248	}
249}
250
251impl ProtocolEncoding for Option<sha256::Hash> {
252	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
253		if let Some(h) = self {
254			w.emit_u8(1)?;
255			w.emit_slice(&h.as_byte_array()[..])
256		} else {
257			w.emit_u8(0)
258		}
259	}
260
261	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
262		let first = r.read_u8()?;
263		if first == 0 {
264			Ok(None)
265		} else if first == 1 {
266			let mut buf = [0u8; 32];
267			r.read_slice(&mut buf)?;
268			Ok(Some(sha256::Hash::from_byte_array(buf)))
269		} else {
270			Err(ProtocolDecodingError::invalid("invalid optional hash prefix byte"))
271		}
272	}
273}
274
275impl ProtocolEncoding for Option<PublicKey> {
276	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
277		if let Some(pk) = self {
278			w.emit_slice(&pk.serialize())
279		} else {
280			w.emit_u8(0)
281		}
282	}
283
284	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
285		let first = r.read_u8()?;
286		if first == 0 {
287			Ok(None)
288		} else {
289			let mut pk = [first; secp256k1::constants::PUBLIC_KEY_SIZE];
290			r.read_slice(&mut pk[1..])?;
291			Ok(Some(PublicKey::from_slice(&pk).map_err(|e| {
292				ProtocolDecodingError::invalid_err(e, "invalid public key")
293			})?))
294		}
295	}
296}
297
298impl ProtocolEncoding for schnorr::Signature {
299	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
300		w.emit_slice(&self.serialize())
301	}
302
303	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
304		let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
305		r.read_slice(&mut buf[..])?;
306		schnorr::Signature::from_slice(&buf).map_err(|e| {
307			ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
308		})
309	}
310}
311
312impl ProtocolEncoding for Option<schnorr::Signature> {
313	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
314		if let Some(sig) = self {
315			w.emit_slice(&sig.serialize())
316		} else {
317			w.emit_slice(&[0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE])
318		}
319	}
320
321	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
322		let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
323		r.read_slice(&mut buf[..])?;
324		if buf == [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE] {
325			Ok(None)
326		} else {
327			Ok(Some(schnorr::Signature::from_slice(&buf).map_err(|e| {
328				ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
329			})?))
330		}
331	}
332}
333
334impl ProtocolEncoding for sha256::Hash {
335	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
336		w.emit_slice(&self[..])
337	}
338
339	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
340		let mut buf = [0; sha256::Hash::LEN];
341		r.read_exact(&mut buf[..])?;
342		Ok(sha256::Hash::from_byte_array(buf))
343	}
344}
345
346impl ProtocolEncoding for musig::PublicNonce {
347	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
348	    w.emit_slice(&self.serialize())
349	}
350	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
351		Ok(Self::from_byte_array(&r.read_byte_array()?).map_err(|e| {
352			ProtocolDecodingError::invalid_err(e, "invalid musig public nonce")
353		})?)
354	}
355}
356
357impl ProtocolEncoding for musig::PartialSignature {
358	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
359	    w.emit_slice(&self.serialize())
360	}
361	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
362		Ok(Self::from_byte_array(&r.read_byte_array()?).map_err(|e| {
363			ProtocolDecodingError::invalid_err(e, "invalid musig public nonce")
364		})?)
365	}
366}
367
368/// A macro to implement our [ProtocolEncoding] for a rust-bitcoin type that
369/// implements their `consensus::Encodable/Decodable` traits.
370macro_rules! impl_bitcoin_encode {
371	($name:ty) => {
372		impl ProtocolEncoding for $name {
373			fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
374				let mut wrapped = bitcoin::io::FromStd::new(w);
375				bitcoin::consensus::Encodable::consensus_encode(self, &mut wrapped)?;
376				Ok(())
377			}
378
379			fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
380				let mut wrapped = bitcoin::io::FromStd::new(r);
381				let ret = bitcoin::consensus::Decodable::consensus_decode(&mut wrapped)?;
382				Ok(ret)
383			}
384		}
385	};
386}
387
388impl_bitcoin_encode!(bitcoin::BlockHash);
389impl_bitcoin_encode!(bitcoin::OutPoint);
390impl_bitcoin_encode!(bitcoin::TxOut);
391
392
393impl<'a, T: ProtocolEncoding + Clone> ProtocolEncoding for Cow<'a, T> {
394	fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error> {
395	    ProtocolEncoding::encode(self.as_ref(), writer)
396	}
397
398	fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError> {
399	    Ok(Cow::Owned(ProtocolEncoding::decode(reader)?))
400	}
401}
402
403
404pub mod serde {
405	//! Module that helps to encode [ProtocolEncoding] objects with serde.
406	//!
407	//! By default, the objects will be encoded as bytes for regular serializers,
408	//! and as hex for human-readable serializers.
409	//!
410	//! Can be used as follows:
411	//! ```no_run
412	//! # use ark::Vtxo;
413	//! # use serde::{Serialize, Deserialize};
414	//! #[derive(Serialize, Deserialize)]
415	//! struct SomeStruct {
416	//! 	#[serde(with = "ark::encode::serde")]
417	//! 	single: Vtxo,
418	//! 	#[serde(with = "ark::encode::serde::vec")]
419	//! 	multiple: Vec<Vtxo>,
420	//! }
421	//! ```
422
423	use std::fmt;
424	use std::borrow::Cow;
425	use std::marker::PhantomData;
426
427	use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
428
429	use super::ProtocolEncoding;
430
431	struct SerWrapper<'a, T>(&'a T);
432
433	impl<'a, T: ProtocolEncoding> Serialize for SerWrapper<'a, T> {
434		fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
435			if s.is_human_readable() {
436				s.serialize_str(&self.0.serialize_hex())
437			} else {
438				s.serialize_bytes(&self.0.serialize())
439			}
440		}
441	}
442
443	struct DeWrapper<T>(T);
444
445	impl<'de, T: ProtocolEncoding> Deserialize<'de> for DeWrapper<T> {
446		fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
447			if d.is_human_readable() {
448				let s = <Cow<'de, str>>::deserialize(d)?;
449				Ok(DeWrapper(ProtocolEncoding::deserialize_hex(s.as_ref())
450					.map_err(serde::de::Error::custom)?))
451			} else {
452				let b = <Cow<'de, [u8]>>::deserialize(d)?;
453				Ok(DeWrapper(ProtocolEncoding::deserialize(b.as_ref())
454					.map_err(serde::de::Error::custom)?))
455			}
456		}
457	}
458
459	pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &T, s: S) -> Result<S::Ok, S::Error> {
460		SerWrapper(v).serialize(s)
461	}
462
463	pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<T, D::Error> {
464		Ok(DeWrapper::<T>::deserialize(d)?.0)
465	}
466
467	pub mod vec {
468		use super::*;
469
470		pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &[T], s: S) -> Result<S::Ok, S::Error> {
471			let mut seq = s.serialize_seq(Some(v.len()))?;
472			for item in v {
473				ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
474			}
475			ser::SerializeSeq::end(seq)
476		}
477
478		pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<Vec<T>, D::Error> {
479			struct Visitor<T>(PhantomData<T>);
480
481			impl<'de, T: ProtocolEncoding> de::Visitor<'de> for Visitor<T> {
482				type Value = Vec<T>;
483
484				fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
485					f.write_str("a vector of objects implementing ProtocolEncoding")
486				}
487
488				fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
489					let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
490					while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
491						ret.push(v.0);
492					}
493					Ok(ret)
494				}
495			}
496			d.deserialize_seq(Visitor(PhantomData))
497		}
498	}
499
500	pub mod cow {
501		use super::*;
502
503		use std::borrow::Cow;
504
505		pub fn serialize<'a, T, S>(v: &Cow<'a, T>, s: S) -> Result<S::Ok, S::Error>
506		where
507			T: ProtocolEncoding + Clone,
508			S: Serializer,
509		{
510			SerWrapper(v.as_ref()).serialize(s)
511		}
512
513		pub fn deserialize<'d, T, D>(d: D) -> Result<Cow<'static, T>, D::Error>
514		where
515			T: ProtocolEncoding + Clone,
516			D: Deserializer<'d>,
517		{
518			Ok(Cow::Owned(DeWrapper::<T>::deserialize(d)?.0))
519		}
520
521		pub mod vec {
522			use super::*;
523
524			use std::borrow::Cow;
525
526			pub fn serialize<'a, T, S>(v: &Cow<'a, [T]>, s: S) -> Result<S::Ok, S::Error>
527			where
528				T: ProtocolEncoding + Clone,
529				S: Serializer,
530			{
531				let mut seq = s.serialize_seq(Some(v.len()))?;
532				for item in v.as_ref().iter() {
533					ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
534				}
535				ser::SerializeSeq::end(seq)
536			}
537
538			pub fn deserialize<'d, T, D>(d: D) -> Result<Cow<'static, [T]>, D::Error>
539			where
540				T: ProtocolEncoding + Clone,
541				D: Deserializer<'d>,
542			{
543				struct Visitor<T>(PhantomData<T>);
544
545				impl<'de, T: ProtocolEncoding + Clone + 'static> de::Visitor<'de> for Visitor<T> {
546					type Value = Cow<'static, [T]>;
547
548					fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
549						f.write_str("a vector of objects implementing ProtocolEncoding")
550					}
551
552					fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
553						let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
554						while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
555							ret.push(v.0);
556						}
557						Ok(ret.into())
558					}
559				}
560				d.deserialize_seq(Visitor(PhantomData))
561			}
562		}
563	}
564}
565
566
#[cfg(any(test, feature = "test-util"))]
pub mod test {
	//! Helpers for testing [ProtocolEncoding] and serde implementations.

	use bitcoin::hex::DisplayHex;
	use ::serde::{Deserialize, Serialize};

	use super::*;

	/// Test that the object's encoding round-trips.
	///
	/// Checks that decoding the binary encoding yields an equal object, that
	/// re-encoding the decoded object yields identical bytes, and that the
	/// hex encoding matches the binary one and decodes to an equal object too.
	///
	/// # Panics
	/// Panics if any of these checks fails.
	pub fn encoding_roundtrip<T>(object: &T)
	where
		T: ProtocolEncoding + fmt::Debug + PartialEq,
	{
		let encoded = object.serialize();
		let decoded = T::deserialize(&encoded).unwrap();

		assert_eq!(*object, decoded);

		let re_encoded = decoded.serialize();
		// Compare as hex so a mismatch produces a readable assertion message.
		let encoded_hex = encoded.as_hex().to_string();
		assert_eq!(encoded_hex, re_encoded.as_hex().to_string());

		let hex = object.serialize_hex();
		assert_eq!(hex, encoded_hex);
		let decoded_from_hex = T::deserialize_hex(&hex).unwrap();
		assert_eq!(*object, decoded_from_hex);
	}

	/// Test that the object round-trips through serde JSON.
	///
	/// # Panics
	/// Panics if JSON serialization or deserialization fails, or if the
	/// deserialized object differs from the original.
	pub fn json_roundtrip<T>(object: &T)
	where
		T: fmt::Debug + PartialEq + Serialize + for<'de> Deserialize<'de>,
	{
		let encoded = serde_json::to_string(object).unwrap();
		let decoded: T = serde_json::from_str(&encoded).unwrap();

		assert_eq!(*object, decoded);
	}
}