From 373ab7c4fc560f226fb36b36a68eea88421990a3 Mon Sep 17 00:00:00 2001 From: Joakim Holm Date: Tue, 9 Jan 2024 14:44:22 +0100 Subject: [PATCH] Make Encryption an interface instead of an union --- grawlix/encryption.py | 50 ++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/grawlix/encryption.py b/grawlix/encryption.py index c992f8f..b4172ed 100644 --- a/grawlix/encryption.py +++ b/grawlix/encryption.py @@ -1,5 +1,5 @@ from Crypto.Cipher import AES -from typing import Union +from typing import Union, Protocol from dataclasses import dataclass @@ -8,6 +8,10 @@ class AESEncryption: key: bytes iv: bytes + def decrypt(self, data: bytes) -> bytes: + cipher = AES.new(self.key, AES.MODE_CBC, self.iv) + return cipher.decrypt(data) + @dataclass(slots=True) class AESCTREncryption: @@ -15,16 +19,31 @@ class AESCTREncryption: nonce: bytes initial_value: bytes + def decrypt(self, data: bytes) -> bytes: + cipher = AES.new( + key = self.key, + mode = AES.MODE_CTR, + nonce = self.nonce, + initial_value = self.initial_value + ) + return cipher.decrypt(data) + @dataclass(slots=True) class XOrEncryption: key: bytes -Encryption = Union[ - AESCTREncryption, - AESEncryption, - XOrEncryption -] + def decrypt(self, data: bytes) -> bytes: + key_length = len(self.key) + decoded = [] + for i in range(0, len(data)): + decoded.append(data[i] ^ self.key[i % key_length]) + return bytes(decoded) + + +class Encryption(Protocol): + def decrypt(self, data: bytes) -> bytes: ... + def decrypt(data: bytes, encryption: Encryption) -> bytes: """ @@ -34,21 +53,4 @@ def decrypt(data: bytes, encryption: Encryption) -> bytes: :param encryption: Information about how to decrypt :returns: Decrypted data """ - if isinstance(encryption, AESCTREncryption): - cipher = AES.new( - key = encryption.key, - mode = AES.MODE_CTR, - nonce = encryption.nonce, - initial_value = encryption.initial_value - ) - return cipher.decrypt(data) - if isinstance(encryption, AESEncryption): - cipher = AES.new(encryption.key, AES.MODE_CBC, encryption.iv) - return cipher.decrypt(data) - if isinstance(encryption, XOrEncryption): - key_length = len(encryption.key) - decoded = [] - for i in range(0, len(data)): - decoded.append(data[i] ^ encryption.key[i % key_length]) - return bytes(decoded) - raise NotImplemented + return encryption.decrypt(data)