-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathverify_installation.py
More file actions
170 lines (132 loc) · 4.67 KB
/
verify_installation.py
File metadata and controls
170 lines (132 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Installation verification script for TensorTalk.
This script checks that all required dependencies are installed and working.
Run this after installing the package to verify everything is set up correctly.
Usage:
python verify_installation.py
"""
import sys
def check_imports(packages=None):
    """Check whether each required package can be imported.

    Prints one OK/FAILED line per package; failures include the exception
    type and message so the user can see *why* an import failed.

    Args:
        packages: Optional mapping of importable module name -> display name.
            Defaults to TensorTalk's required dependencies.

    Returns:
        List of ``(display_name, module_name)`` tuples for every package that
        failed to import; empty when everything imported.
    """
    import importlib

    if packages is None:
        packages = {
            'torch': 'PyTorch',
            'torchaudio': 'TorchAudio',
            'transformers': 'Transformers',
            'gtts': 'gTTS',
            'numpy': 'NumPy',
        }

    print("Checking package imports...")
    print("-" * 60)

    failed = []
    for package, name in packages.items():
        try:
            importlib.import_module(package)
        except Exception as e:
            # Catch more than ImportError: an installed-but-broken package can
            # raise other errors at import time (e.g. OSError when a native
            # library fails to load). Report it instead of aborting the run.
            print(f"✗ {name:20s} ... FAILED ({type(e).__name__}: {e})")
            failed.append((name, package))
        else:
            print(f"✓ {name:20s} ... OK")
    return failed
def check_cuda():
    """Report whether PyTorch can see a CUDA device.

    When a GPU is visible, prints the name of device 0 and the CUDA version
    PyTorch was built against; otherwise prints a CPU-only warning. Running
    without CUDA is not treated as a failure.

    Returns:
        True if the check completed (with or without CUDA); False if importing
        torch or querying CUDA raised an exception.
    """
    print("\nChecking CUDA...")
    print("-" * 60)
    try:
        import torch

        if not torch.cuda.is_available():
            # CPU-only is a supported configuration, just a slower one.
            print("⚠ CUDA not available (CPU only)")
            print(" This is okay, but processing will be slower.")
            return True

        print("✓ CUDA is available")
        print(f" Device: {torch.cuda.get_device_name(0)}")
        print(f" CUDA Version: {torch.version.cuda}")
        return True
    except Exception as err:
        print(f"✗ Error checking CUDA: {err}")
        return False
def check_models(model_name="microsoft/wavlm-large"):
    """Check that a WavLM checkpoint can be loaded through ``transformers``.

    ``from_pretrained`` downloads the checkpoint if it is not already cached,
    so this can be slow and may need network access.

    Args:
        model_name: Model id or local path passed to
            ``WavLMModel.from_pretrained``. Defaults to the checkpoint
            TensorTalk uses.

    Returns:
        True if the model loaded; False if ``transformers``/``WavLMModel``
        could not be imported or loading failed.
    """
    print("\nChecking model loading...")
    print("-" * 60)

    try:
        from transformers import WavLMModel
    except Exception as e:
        print(f"✗ Error: {e}")
        return False

    # Not a ✓ yet: nothing has been verified until from_pretrained returns.
    print("… WavLM model ... Loading (this may take a moment)")
    try:
        model = WavLMModel.from_pretrained(model_name)
    except Exception as e:
        print(f"✗ WavLM model ... Failed to load: {e}")
        return False

    print("✓ WavLM model ... Loaded successfully")
    del model  # Drop the reference so the weights can be freed.
    return True
def check_tensortalk():
    """Check that the TensorTalk package (imported as ``src``) loads.

    ``'.'`` is put on ``sys.path`` so ``src`` resolves relative to the current
    working directory. The script must therefore be run from the TensorTalk
    root.

    Returns:
        True if ``TensorTalkPipeline``, ``SSLEncoder`` and ``KNNMatcher``
        import cleanly, False otherwise.
    """
    print("\nChecking TensorTalk package...")
    print("-" * 60)

    # Add '.' only once so repeated calls don't keep growing sys.path.
    if '.' not in sys.path:
        sys.path.append('.')

    try:
        from src import TensorTalkPipeline, SSLEncoder, KNNMatcher  # noqa: F401
    except ImportError as e:
        print(f"✗ Failed to import TensorTalk: {e}")
        print("\nMake sure you're running this from the TensorTalk root directory.")
        return False
    except Exception as e:
        # The package was found but raised while initialising (e.g. a broken
        # dependency). Report it rather than aborting the whole verification.
        print(f"✗ TensorTalk import raised {type(e).__name__}: {e}")
        return False

    print("✓ TensorTalkPipeline ... OK")
    print("✓ SSLEncoder ... OK")
    print("✓ KNNMatcher ... OK")
    return True
def _ask_yes_no(prompt):
    """Prompt the user; return True only for a 'y'/'yes' answer (any case)."""
    try:
        answer = input(prompt)
    except EOFError:
        # Non-interactive stdin (CI, piped or /dev/null) cannot answer:
        # treat it as "no" instead of crashing the verification run.
        print()
        return False
    return answer.strip().lower() in ('y', 'yes')


def main():
    """Run every verification check, print a summary and return an exit code.

    Returns:
        0 when every check passed, 1 otherwise (suitable for ``sys.exit``).
    """
    print("=" * 60)
    print("TensorTalk Installation Verification")
    print("=" * 60)
    print()

    failed_imports = check_imports()
    # List elements are evaluated left to right, so the checks run in order.
    results = [
        ("Package Imports", not failed_imports),
        ("CUDA Check", check_cuda()),
        ("TensorTalk Import", check_tensortalk()),
    ]

    # Model loading is optional because it may download a large checkpoint.
    print("\n" + "=" * 60)
    if _ask_yes_no("Load WavLM model to test? (y/n, this may take time): "):
        results.append(("Model Loading", check_models()))

    # Summary
    print("\n" + "=" * 60)
    print("Verification Summary")
    print("=" * 60)
    for name, status in results:
        symbol = "✓" if status else "✗"
        print(f"{symbol} {name}")
    print("=" * 60)
    all_ok = all(status for _, status in results)

    if failed_imports:
        print("\n⚠ Missing packages:")
        for name, package in failed_imports:
            print(f" - {name} (pip install {package})")
        print("\nInstall missing packages with:")
        print(" pip install -r requirements.txt")

    if all_ok:
        print("\n✓ All checks passed! TensorTalk is ready to use.")
        print("\nNext steps:")
        print(" 1. Check out notebooks/demo.ipynb for a tutorial")
        print(" 2. Run simple_example.py for a quick test")
        print(" 3. Read the paper: TensorTalk_Paper.pdf")
        return 0

    print("\n✗ Some checks failed. Please fix the issues above.")
    return 1
if __name__ == "__main__":
    # main() returns the process exit status (0 = all checks passed).
    raise SystemExit(main())