1use crate::{
2 MctpMedium, MctpMessageHeaderTrait, MctpMessageTrait,
3 MctpPacketError::{self, HeaderParseError},
4 error::{MctpPacketResult, ProtocolError},
5 mctp_command_code::MctpControlCommandCode,
6 mctp_completion_code::MctpCompletionCode,
7};
8
9#[derive(Debug, Default, PartialEq, Eq, Clone)]
10#[cfg_attr(feature = "defmt", derive(defmt::Format))]
11pub struct MctpControlHeader {
12 pub request_bit: bool, pub datagram_bit: bool, pub instance_id: u8, pub command_code: MctpControlCommandCode,
16 pub completion_code: MctpCompletionCode,
17}
18
/// Body of an MCTP control message.
///
/// The variant is chosen from the header's command code and request bit.
/// Each payload array holds the raw body bytes, copied verbatim to and from
/// the wire.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum MctpControl {
    /// Set Endpoint ID request: two raw body bytes.
    /// NOTE(review): presumably (operation, EID) per DSP0236 — confirm.
    SetEndpointIdRequest([u8; 2]),
    /// Set Endpoint ID response: three raw body bytes.
    SetEndpointIdResponse([u8; 3]),
    /// Get Endpoint ID request: carries no body bytes.
    GetEndpointIdRequest,
    /// Get Endpoint ID response: three raw body bytes.
    GetEndpointIdResponse([u8; 3]),
}
27
28impl MctpMessageHeaderTrait for MctpControlHeader {
29 fn serialize<M: MctpMedium>(self, buffer: &mut [u8]) -> MctpPacketResult<usize, M> {
30 if buffer.len() < 3 {
31 return Err(crate::MctpPacketError::SerializeError(
32 "buffer too small for mctp control header",
33 ));
34 }
35
36 check_request_and_completion_code(self.request_bit, self.completion_code)?;
37
38 buffer[0] = (self.request_bit as u8) << 7
39 | (self.datagram_bit as u8) << 6
40 | (self.instance_id & 0b0001_1111);
41 buffer[1] = self.command_code as u8;
42 buffer[2] = self.completion_code.into();
43 Ok(3)
44 }
45
46 fn deserialize<M: MctpMedium>(buffer: &[u8]) -> MctpPacketResult<(Self, &[u8]), M> {
47 if buffer.len() < 3 {
48 return Err(HeaderParseError("buffer too small for mctp control header"));
49 }
50
51 let request_bit = buffer[0] & 0b1000_0000 != 0;
52 let datagram_bit = buffer[0] & 0b0100_0000 != 0;
53 let instance_id = buffer[0] & 0b0001_1111;
54 let command_code = MctpControlCommandCode::try_from(buffer[1])
55 .map_err(|_| HeaderParseError("invalid mctp command code"))?;
56 let completion_code = MctpCompletionCode::try_from(buffer[2])
57 .map_err(|_| HeaderParseError("invalid mctp completion code"))?;
58
59 check_request_and_completion_code(request_bit, completion_code)?;
60
61 Ok((
62 MctpControlHeader {
63 request_bit,
64 datagram_bit,
65 instance_id,
66 command_code,
67 completion_code,
68 },
69 &buffer[3..],
70 ))
71 }
72}
73
74fn check_request_and_completion_code<M: MctpMedium>(
75 request_bit: bool,
76 completion_code: MctpCompletionCode,
77) -> MctpPacketResult<(), M> {
78 if request_bit && completion_code != MctpCompletionCode::Success {
79 return Err(MctpPacketError::ProtocolError(
80 ProtocolError::CompletionCodeOnRequestMessage(completion_code),
81 ));
82 }
83 Ok(())
84}
85
86impl<'buf> MctpMessageTrait<'buf> for MctpControl {
87 type Header = MctpControlHeader;
88 const MESSAGE_TYPE: u8 = 0x00;
89
90 fn serialize<M: MctpMedium>(self, buffer: &mut [u8]) -> MctpPacketResult<usize, M> {
91 match self {
92 Self::SetEndpointIdRequest(data) => copy_and_check_len(buffer, data),
93 Self::SetEndpointIdResponse(data) => copy_and_check_len(buffer, data),
94 Self::GetEndpointIdRequest => copy_and_check_len(buffer, []),
95 Self::GetEndpointIdResponse(data) => copy_and_check_len(buffer, data),
96 }
97 }
98
99 fn deserialize<M: MctpMedium>(
100 header: &Self::Header,
101 buffer: &'buf [u8],
102 ) -> MctpPacketResult<Self, M> {
103 let message = match (header.request_bit, header.command_code) {
104 (true, MctpControlCommandCode::SetEndpointId) => {
105 Self::SetEndpointIdRequest(try_into_array(buffer)?)
106 }
107 (true, MctpControlCommandCode::GetEndpointId) => Self::GetEndpointIdRequest,
108 (false, MctpControlCommandCode::SetEndpointId) => {
109 Self::SetEndpointIdResponse(try_into_array(buffer)?)
110 }
111 (false, MctpControlCommandCode::GetEndpointId) => {
112 Self::GetEndpointIdResponse(try_into_array(buffer)?)
113 }
114 _ => {
115 return Err(HeaderParseError("invalid mctp control command code"));
116 }
117 };
118 Ok(message)
119 }
120}
121
122fn copy_and_check_len<const N: usize, M: MctpMedium>(
123 buffer: &mut [u8],
124 data: [u8; N],
125) -> MctpPacketResult<usize, M> {
126 if buffer.len() < N {
127 return Err(crate::MctpPacketError::SerializeError(
128 "buffer too small for mctp control message",
129 ));
130 }
131 buffer[..N].copy_from_slice(&data);
132 Ok(N)
133}
134
135fn try_into_array<const N: usize, M: MctpMedium>(buffer: &[u8]) -> MctpPacketResult<[u8; N], M> {
136 if buffer.len() < N {
137 return Err(HeaderParseError(
138 "buffer too small for mctp control message",
139 ));
140 }
141 Ok(buffer[..N].try_into().unwrap())
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::ProtocolError, test_util::TestMedium};

    #[test]
    fn header_serialize_deserialize_happy_path() {
        let header = MctpControlHeader {
            request_bit: true,
            datagram_bit: false,
            instance_id: 0b1_1111,
            command_code: MctpControlCommandCode::GetEndpointId,
            completion_code: MctpCompletionCode::Success,
        };

        let mut buf = [0u8; 3];
        let size = header.clone().serialize::<TestMedium>(&mut buf).unwrap();
        assert_eq!(size, 3);
        assert_eq!(
            buf,
            [
                0b1000_0000 | 0b0001_1111,
                MctpControlCommandCode::GetEndpointId as u8,
                u8::from(MctpCompletionCode::Success),
            ]
        );

        let (parsed, rest) = MctpControlHeader::deserialize::<TestMedium>(&buf).unwrap();
        assert_eq!(parsed, header);
        assert_eq!(rest.len(), 0);
    }

    /// Response headers may carry any completion code. Bytes after the
    /// 3-byte header must come back as `rest`.
    #[test]
    fn header_round_trip_response_with_datagram_bit_and_trailing_bytes() {
        let header = MctpControlHeader {
            request_bit: false,
            datagram_bit: true,
            instance_id: 0b0_0101,
            command_code: MctpControlCommandCode::SetEndpointId,
            completion_code: MctpCompletionCode::Error,
        };

        let mut buf = [0u8; 4];
        buf[3] = 0x5A;
        let size = header.clone().serialize::<TestMedium>(&mut buf).unwrap();
        assert_eq!(size, 3);
        assert_eq!(buf[0], 0b0100_0000 | 0b0_0101);

        let (parsed, rest) = MctpControlHeader::deserialize::<TestMedium>(&buf).unwrap();
        assert_eq!(parsed, header);
        assert_eq!(rest, &[0x5A]);
    }

    #[test]
    fn header_serialize_error_on_completion_code_in_request() {
        let header = MctpControlHeader {
            request_bit: true,
            datagram_bit: false,
            instance_id: 0,
            command_code: MctpControlCommandCode::SetEndpointId,
            completion_code: MctpCompletionCode::Error,
        };

        let mut buf = [0u8; 3];
        let err = header.serialize::<TestMedium>(&mut buf).unwrap_err();
        match err {
            MctpPacketError::ProtocolError(ProtocolError::CompletionCodeOnRequestMessage(code)) => {
                assert_eq!(code, MctpCompletionCode::Error)
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn header_serialize_error_on_small_buffer() {
        let header = MctpControlHeader {
            request_bit: true,
            datagram_bit: false,
            instance_id: 0,
            command_code: MctpControlCommandCode::GetEndpointId,
            completion_code: MctpCompletionCode::Success,
        };

        let mut buf = [0u8; 2];
        let err = header.serialize::<TestMedium>(&mut buf).unwrap_err();
        match err {
            MctpPacketError::SerializeError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control header")
            }
            other => panic!("unexpected error: {:?}", other),
        }
        // A rejected buffer must be left untouched.
        assert_eq!(buf, [0u8; 2]);
    }

    #[test]
    fn header_deserialize_error_on_small_buffer() {
        let err = MctpControlHeader::deserialize::<TestMedium>(&[0u8; 2]).unwrap_err();
        match err {
            MctpPacketError::HeaderParseError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control header")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn header_deserialize_error_on_completion_code_in_request() {
        let buf = [
            0b1000_0000,
            MctpControlCommandCode::GetEndpointId as u8,
            u8::from(MctpCompletionCode::Error),
        ];
        let err = MctpControlHeader::deserialize::<TestMedium>(&buf).unwrap_err();
        match err {
            MctpPacketError::ProtocolError(ProtocolError::CompletionCodeOnRequestMessage(code)) => {
                assert_eq!(code, MctpCompletionCode::Error)
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[rstest::rstest]
    #[case(MctpControlCommandCode::SetEndpointId, false, MctpControl::SetEndpointIdResponse([0xAA, 0xBB, 0xCC]), &[0xAA, 0xBB, 0xCC])]
    #[case(MctpControlCommandCode::SetEndpointId, true, MctpControl::SetEndpointIdRequest([0xAA, 0xBB]), &[0xAA, 0xBB])]
    #[case(MctpControlCommandCode::GetEndpointId, false, MctpControl::GetEndpointIdResponse([0xAA, 0xBB, 0xCC]), &[0xAA, 0xBB, 0xCC])]
    #[case(MctpControlCommandCode::GetEndpointId, true, MctpControl::GetEndpointIdRequest, &[])]
    fn message_serialize_deserialize_happy_path(
        #[case] command_code: MctpControlCommandCode,
        #[case] request_bit: bool,
        #[case] message: MctpControl,
        #[case] expected: &[u8],
    ) {
        let mut buf = [0u8; 1024];
        let size = message.clone().serialize::<TestMedium>(&mut buf).unwrap();
        assert_eq!(size, expected.len());
        assert_eq!(&buf[..size], expected);

        let header = MctpControlHeader {
            request_bit,
            datagram_bit: false,
            instance_id: 0,
            command_code,
            completion_code: MctpCompletionCode::Success,
        };

        // Parse exactly the serialized bytes. Parsing from the full scratch
        // buffer would hide a deserializer that reads past what was written.
        let parsed = MctpControl::deserialize::<TestMedium>(&header, &buf[..size]).unwrap();
        assert_eq!(parsed, message);
    }

    #[test]
    fn message_serialize_error_on_small_buffer() {
        let mut buf = [0u8; 2];
        let err = MctpControl::SetEndpointIdResponse([0xAA, 0xBB, 0xCC])
            .serialize::<TestMedium>(&mut buf)
            .unwrap_err();
        match err {
            MctpPacketError::SerializeError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control message")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn message_deserialize_error_on_short_payload() {
        let header = MctpControlHeader {
            request_bit: true,
            datagram_bit: false,
            instance_id: 0,
            command_code: MctpControlCommandCode::SetEndpointId,
            completion_code: MctpCompletionCode::Success,
        };

        let err = MctpControl::deserialize::<TestMedium>(&header, &[0xAA]).unwrap_err();
        match err {
            MctpPacketError::HeaderParseError(msg) => {
                assert_eq!(msg, "buffer too small for mctp control message")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }

    #[test]
    fn message_deserialize_error_on_invalid_command_for_header() {
        let header = MctpControlHeader {
            request_bit: true,
            datagram_bit: false,
            instance_id: 0,
            command_code: MctpControlCommandCode::Reserved,
            completion_code: MctpCompletionCode::Success,
        };

        let err = MctpControl::deserialize::<TestMedium>(&header, &[]).unwrap_err();
        match err {
            MctpPacketError::HeaderParseError(msg) => {
                assert_eq!(msg, "invalid mctp control command code")
            }
            other => panic!("unexpected error: {:?}", other),
        }
    }
}
243}