-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsecurity.py
More file actions
396 lines (315 loc) · 11.6 KB
/
security.py
File metadata and controls
396 lines (315 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
"""
Security Module
Handles credential encryption/decryption, input validation, and rate limiting.
"""
import base64
import hashlib
import json
import math
import os
import re
import time
from collections import deque
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

try:
    from cryptography.fernet import Fernet, InvalidToken
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
except ImportError:
    raise ImportError("cryptography library is required. Install it with: pip install cryptography")

import config
class SecurityException(Exception):
    """Root of this module's exception hierarchy for security-related failures."""
class ValidationError(SecurityException):
    """Signals that a supplied value did not pass an input check."""
class RateLimitError(SecurityException):
    """Signals that a call was refused because too many calls fell inside the rate window."""
class CredentialManager:
    """
    Manages encrypted storage and retrieval of API credentials.

    With a master password, the credentials JSON is encrypted with Fernet
    under a key derived from the password via PBKDF2-HMAC-SHA256 and a
    random salt persisted in ``config.SALT_FILE``. Without a password the
    credentials are written as plain JSON (not recommended).
    """

    def __init__(self, credentials_file: str = config.CREDENTIALS_FILE):
        """
        Initialize credential manager.

        Args:
            credentials_file: Path to the credentials file (encrypted or plain JSON)
        """
        self.credentials_file = credentials_file
        self.salt_file = config.SALT_FILE

    @staticmethod
    def _private_opener(path: str, flags: int) -> int:
        """Opener for ``open()`` that creates files with owner-only (0o600) permissions."""
        # The mode applies only when the file is newly created (and is largely
        # ignored on Windows); an existing file keeps its current permissions.
        return os.open(path, flags, 0o600)

    def _derive_key(self, password: str, salt: bytes) -> bytes:
        """
        Derive a raw 32-byte key from a password using PBKDF2-HMAC-SHA256.

        Args:
            password: User password
            salt: Salt for key derivation

        Returns:
            32 raw key bytes (not base64-encoded)
        """
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=config.PASSWORD_ITERATIONS,
            backend=default_backend()
        )
        return kdf.derive(password.encode())

    def _fernet_for(self, password: str, salt: bytes) -> Fernet:
        """
        Build the Fernet cipher for a password/salt pair.

        Fernet takes its key as url-safe base64 of 32 bytes, so the raw
        PBKDF2 output is encoded before use. The key is deterministic in
        (password, salt), which is what lets load_credentials decrypt what
        save_credentials wrote.

        Args:
            password: Master password
            salt: Salt from the salt file

        Returns:
            Fernet instance keyed from the password
        """
        return Fernet(base64.urlsafe_b64encode(self._derive_key(password, salt)))

    def _get_or_create_salt(self) -> bytes:
        """
        Return the stored salt, creating and persisting a random 16-byte salt
        if the salt file does not exist yet.

        Returns:
            Salt bytes
        """
        if os.path.exists(self.salt_file):
            with open(self.salt_file, 'rb') as f:
                return f.read()
        salt = os.urandom(16)
        with open(self.salt_file, 'wb', opener=self._private_opener) as f:
            f.write(salt)
        return salt

    def save_credentials(
        self,
        api_key: str,
        api_secret: str,
        paper: bool = True,
        password: Optional[str] = None
    ) -> bool:
        """
        Save credentials to file, encrypting them when a password is given.

        Args:
            api_key: Alpaca API key
            api_secret: Alpaca API secret
            paper: Whether using paper trading
            password: Master password for encryption (optional; without it
                the credentials are stored as plain JSON)

        Returns:
            True if successful

        Raises:
            ValidationError: If the API key or secret is empty
            SecurityException: If encryption or writing the file fails
        """
        # Validate before the try block so ValidationError reaches the caller
        # as-is rather than being re-wrapped in a generic SecurityException.
        if not api_key or not api_secret:
            raise ValidationError("API key and secret cannot be empty")

        credentials = {
            'api_key': api_key,
            'api_secret': api_secret,
            'paper': paper,
            'saved_at': datetime.now().isoformat()
        }

        try:
            if password:
                salt = self._get_or_create_salt()
                # The key is re-derivable from password + salt, so nothing
                # key-related needs to be stored with the ciphertext.
                token = self._fernet_for(password, salt).encrypt(
                    json.dumps(credentials).encode()
                )
                with open(self.credentials_file, 'wb',
                          opener=self._private_opener) as f:
                    f.write(token)
            else:
                # Save without encryption (not recommended)
                with open(self.credentials_file, 'w', encoding='utf-8',
                          opener=self._private_opener) as f:
                    json.dump(credentials, f)
            return True
        except Exception as e:
            raise SecurityException(f"Failed to save credentials: {str(e)}") from e

    def load_credentials(self, password: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Load credentials from file, decrypting them when a password is given.

        The file is treated as Fernet-encrypted only when a password is given
        and the salt file exists; otherwise it is read as plain JSON.

        Args:
            password: Master password for decryption (optional)

        Returns:
            Credentials dictionary or None if the credentials file is missing

        Raises:
            SecurityException: If the password is wrong, the file is corrupt,
                or it cannot be read or parsed
        """
        if not os.path.exists(self.credentials_file):
            return None

        try:
            if password and os.path.exists(self.salt_file):
                with open(self.credentials_file, 'rb') as f:
                    token = f.read()
                salt = self._get_or_create_salt()
                decrypted = self._fernet_for(password, salt).decrypt(token)
                return json.loads(decrypted.decode())
            # Try loading as plain JSON (backwards compatibility)
            with open(self.credentials_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except InvalidToken as e:
            # Fernet's authentication failure: wrong key or tampered/non-Fernet data.
            raise SecurityException(
                "Failed to load credentials: invalid password or corrupted credentials file"
            ) from e
        except Exception as e:
            raise SecurityException(f"Failed to load credentials: {str(e)}") from e

    def credentials_exist(self) -> bool:
        """
        Check if credentials file exists.

        Returns:
            True if credentials file exists
        """
        return os.path.exists(self.credentials_file)

    def delete_credentials(self) -> bool:
        """
        Delete the stored credentials file and its salt file.

        Missing files are not an error.

        Returns:
            True if both files are gone afterwards, False if removal failed
            with an OS error (e.g. permissions)
        """
        try:
            for path in (self.credentials_file, self.salt_file):
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass
            return True
        except OSError:
            return False
class InputValidator:
    """Validates user inputs to prevent injection attacks and errors."""

    @staticmethod
    def validate_symbol(symbol: str) -> bool:
        """
        Validate stock symbol format.

        The symbol is upper-cased before matching, so lowercase input is
        accepted. The entire string must match ``config.SYMBOL_REGEX``.

        Args:
            symbol: Stock symbol to validate

        Returns:
            True if valid

        Raises:
            ValidationError: If validation fails
        """
        if not symbol:
            raise ValidationError("Symbol cannot be empty")
        # fullmatch rather than match: with match, a pattern anchored by `$`
        # still accepts a trailing newline ("AAPL\n"), and an unanchored
        # pattern accepts any suffix.
        if not isinstance(symbol, str) or not re.fullmatch(config.SYMBOL_REGEX, symbol.upper()):
            raise ValidationError(
                f"Invalid symbol format: {symbol}. "
                "Symbols must be 1-5 uppercase letters."
            )
        return True

    @staticmethod
    def validate_quantity(quantity: int) -> bool:
        """
        Validate order quantity.

        Args:
            quantity: Quantity to validate (int, integral float, or numeric string)

        Returns:
            True if valid

        Raises:
            ValidationError: If the value is not an integer or is outside
                [config.MIN_QUANTITY, config.MAX_QUANTITY]
        """
        # int() would silently truncate 2.7 to 2 and validate a quantity the
        # caller never asked for; also rejects inf/nan floats.
        if isinstance(quantity, float) and not quantity.is_integer():
            raise ValidationError("Quantity must be a valid integer")
        try:
            qty = int(quantity)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValidationError("Quantity must be a valid integer") from err
        if qty < config.MIN_QUANTITY:
            raise ValidationError(
                f"Quantity must be at least {config.MIN_QUANTITY}"
            )
        if qty > config.MAX_QUANTITY:
            raise ValidationError(
                f"Quantity cannot exceed {config.MAX_QUANTITY}"
            )
        return True

    @staticmethod
    def validate_price(price: float, *, max_price: float = 1000000) -> bool:
        """
        Validate price value.

        Args:
            price: Price to validate (anything float() accepts)
            max_price: Upper bound above which the price is rejected

        Returns:
            True if valid

        Raises:
            ValidationError: If the price is not a number, is NaN, is not
                positive, or exceeds max_price
        """
        try:
            p = float(price)
        except (TypeError, ValueError, OverflowError) as err:
            raise ValidationError("Price must be a valid number") from err
        # NaN compares False against every bound, so it would otherwise pass
        # both range checks below.
        if math.isnan(p):
            raise ValidationError("Price must be a valid number")
        if p <= 0:
            raise ValidationError("Price must be positive")
        if p > max_price:
            raise ValidationError("Price seems unreasonably high")
        return True

    @staticmethod
    def sanitize_input(text: str) -> str:
        """
        Sanitize text input by removing potentially harmful characters.

        Keeps word characters, hyphens, dots and spaces. Every other
        whitespace character (newline, tab, carriage return, other control
        whitespace) is turned into a plain space, and the result is stripped.

        Args:
            text: Text to sanitize

        Returns:
            Sanitized text
        """
        # Drop everything except word characters, whitespace, '-' and '.'.
        sanitized = re.sub(r'[^\w\s\-\.]', '', text)
        # \s also matches control characters (\n, \r, \t, \x0b, \x0c,
        # \x1c-\x1f, \x85); normalize them to a space so line breaks cannot
        # be injected into downstream output.
        sanitized = re.sub(r'[^\S ]', ' ', sanitized)
        return sanitized.strip()
class RateLimiter:
    """
    Rate limiter to prevent API abuse.

    Sliding window: the timestamps of accepted calls are kept in a deque and
    discarded once they are older than ``window_seconds``. Timestamps come
    from ``time.monotonic()``, so they are only meaningful relative to each
    other, not as wall-clock times.

    NOTE: no locking is done here; if threads share an instance, guard it
    externally.
    """

    def __init__(
        self,
        max_calls: int = config.RATE_LIMIT_CALLS,
        window_seconds: int = config.RATE_LIMIT_WINDOW
    ):
        """
        Initialize rate limiter.

        Args:
            max_calls: Maximum calls allowed in window
            window_seconds: Time window in seconds
        """
        self.max_calls = max_calls
        self.window_seconds = window_seconds
        self.calls = deque()

    def _prune(self, now: float) -> None:
        """Drop recorded calls that fall outside the window ending at ``now``."""
        cutoff = now - self.window_seconds
        while self.calls and self.calls[0] < cutoff:
            self.calls.popleft()

    def check_rate_limit(self) -> bool:
        """
        Check if rate limit allows another call, and record it if so.

        Returns:
            True if call is allowed

        Raises:
            RateLimitError: If rate limit exceeded
        """
        # Monotonic clock: NTP or manual wall-clock changes would otherwise
        # make old calls look like they are in the future (blocking for the
        # size of the jump) or expire them all at once.
        now = time.monotonic()
        self._prune(now)

        if len(self.calls) >= self.max_calls:
            # With max_calls <= 0 the deque is empty; fall back to a full
            # window instead of raising IndexError on self.calls[0].
            if self.calls:
                wait_time = self.calls[0] + self.window_seconds - now
            else:
                wait_time = float(self.window_seconds)
            raise RateLimitError(
                f"Rate limit exceeded. Please wait {wait_time:.1f} seconds."
            )

        self.calls.append(now)
        return True

    def reset(self):
        """Reset the rate limiter."""
        self.calls.clear()

    def get_remaining_calls(self) -> int:
        """
        Get number of remaining calls in current window.

        Returns:
            Number of remaining calls (never negative)
        """
        self._prune(time.monotonic())
        return max(0, self.max_calls - len(self.calls))
# Single module-level limiter, constructed with RateLimiter's configured
# defaults; every importer of this module shares this one object.
global_rate_limiter: RateLimiter = RateLimiter()