diff --git a/connection/connections/plaid_client_v2.py b/connection/connections/plaid_client_v2.py index 96ef0f6..fd92c75 100755 --- a/connection/connections/plaid_client_v2.py +++ b/connection/connections/plaid_client_v2.py @@ -7,6 +7,9 @@ import plaid from plaid.api import plaid_api from plaid.model.link_token_create_request import LinkTokenCreateRequest from plaid.model.link_token_create_request_user import LinkTokenCreateRequestUser +from plaid.model.item_public_token_exchange_request import ItemPublicTokenExchangeRequest +from plaid.model.transactions_sync_request import TransactionsSyncRequest +from plaid.model.accounts_get_request import AccountsGetRequest from plaid.model.products import Products from plaid.model.country_code import CountryCode @@ -23,7 +26,8 @@ def format_error(e): class Connection(AbstractConnectionClient): def __init__(self, credentials, account_id=None): - self.credentials = credentials.dict() + self.credentials = credentials + self.account_id = account_id # Fill in your Plaid API keys - # https://dashboard.plaid.com/account/keys @@ -51,9 +55,29 @@ class Connection(AbstractConnectionClient): 'secret': self.PLAID_SECRET, } ) - api_client = plaid.ApiClient(configuration) - client = plaid_api.PlaidApi(api_client) + self.api_client = plaid.ApiClient(configuration) + self.client = plaid_api.PlaidApi(self.api_client) + # # Create a link_token for the given user + # request = LinkTokenCreateRequest( + # products=[Products("auth")], + # client_name="Qrtr Plaid", + # country_codes=[CountryCode('US')], + # #redirect_uri='https://domainname.com/oauth-page.html', + # language='en', + # webhook='https://webhook.example.com', + # user=LinkTokenCreateRequestUser( + # client_user_id=self.account_id + # ) + # ) + # response = client.link_token_create(request) + # resp_dict = response.to_dict() + # resp_dict['expiration'] = resp_dict['expiration'].strftime('%s') + + # self.credentials.update(resp_dict) + # return self.credentials + + def generate_auth_request(self): # Create a link_token for the given user request = LinkTokenCreateRequest( products=[Products("auth")], @@ -63,13 +87,14 @@ class Connection(AbstractConnectionClient): language='en', webhook='https://webhook.example.com', user=LinkTokenCreateRequestUser( - client_user_id=account_id + client_user_id=self.account_id ) ) - response = client.link_token_create(request) + response = self.client.link_token_create(request) + resp_dict = response.to_dict() + resp_dict['expiration'] = resp_dict['expiration'].strftime('%s') - self.credentials.update(response.to_dict()) - # return self.credentials + self.credentials.update(resp_dict) def get_auth_token(self, public_token): try: @@ -84,11 +109,12 @@ class Connection(AbstractConnectionClient): return format_error(e) access_token = exchange_response['access_token'] item_id = exchange_response['item_id'] + self.credentials.update({"access_token":access_token, "item_id":item_id}) return {"access_token":access_token, "item_id":item_id} def get_accounts(self, auth_token=None): if not auth_token: - auth_token = self.credentials.get('auth_token') + auth_token = self.credentials.get('access_token') if not auth_token: raise Exception("Missing Auth Token") try: @@ -97,7 +123,7 @@ class Connection(AbstractConnectionClient): except Exception as e: print(e) accounts = None - return accounts + return accounts.get('accounts') def get_transactions( self, @@ -105,22 +131,20 @@ class Connection(AbstractConnectionClient): end_date=None, auth_token=None): if not auth_token: - auth_token = self.credentials.get('auth_token') - if not auth_token: - raise Exception("Missing Auth Token") - if not start_date: - start_date = '{:%Y-%m-%d}'.format( - datetime.datetime.now() + datetime.timedelta(-30)) - if not end_date: - end_date = '{:%Y-%m-%d}'.format(datetime.datetime.now()) - try: - transactions_req = TransactionsGetRequest( + auth_token = self.credentials.get('access_token') + request = TransactionsSyncRequest( + access_token=auth_token, + ) + response = self.client.transactions_sync(request) + transactions = response['added'] + + # the transactions in the response are paginated, so make multiple calls while incrementing the cursor to + # retrieve all transactions + while (response['has_more']): + request = TransactionsSyncRequest( access_token=auth_token, - start_date=start_date, - end_date=end_date + cursor=response['next_cursor'] ) - transactions_resp = self.client.transactions_get( - transactions_req) - except plaid.errors.PlaidError as e: - return format_error(e) - return transactions_resp.get("transactions") + response = self.client.transactions_sync(request) + transactions += response['added'] + return transactions diff --git a/connection/views.py b/connection/views.py index a79ebe5..6c44e19 100644 --- a/connection/views.py +++ b/connection/views.py @@ -33,6 +33,48 @@ class ConnectionViewSet(viewsets.ModelViewSet): 'delete', 'options'] + + @action(detail=False, methods=['post'], url_path='plaid/exchange_public_token') + def exchange_public_token(self, request): + print(f"REQUEST: {request.data}") + name = request.data.get("name", "dummyName") + account_id = request.data.get("account") + public_token = request.data.get("public_token") + user = request.user + accounts = (Account.objects.filter(pk=account_id, owner=user) | + Account.objects.filter(pk=account_id, + admin_users__in=[user])) + if not accounts: + return Response( + status=status.HTTP_400_BAD_REQUEST, + data="ERROR: Account ID not found") + else: + print(f"Account Found: {accounts[0]}") + account = accounts[0] + print(request) + plaid_conn = importlib.import_module(f"connection.connections.plaid_client_v2") + conn_type = ConnectionType.objects.get(name="Plaid") + try: + plaid_client = plaid_conn.Connection(request.data.dict(), account_id=account_id) + token = plaid_client.get_auth_token(public_token) + except ValueError: + return Response(status=status.HTTP_503, + data="ERROR: Invalid public_token") + with transaction.atomic(): + conn, created = Connection.objects \ + .get_or_create(name=name, type=conn_type, + defaults={ + "credentials": request.data, + "account": account + }) + conn.credentials = plaid_client.credentials + print(f"CREDS: {plaid_client.credentials}") + conn.save() + return Response(plaid_client.get_accounts()) + + + + @action(detail=False, methods=['post'], url_path='plaid') def authenticate(self, request): print(request.data) @@ -66,7 +108,8 @@ class ConnectionViewSet(viewsets.ModelViewSet): plaid_conn = importlib.import_module(f"connection.connections.plaid_client_v2") conn_type = ConnectionType.objects.get(name="Plaid") try: - plaid_client = plaid_conn.Connection(request.data, account_id=account_id) + plaid_client = plaid_conn.Connection(request.data.dict(), account_id=account_id) + plaid_client.generate_auth_request() except ValueError: return Response(status=status.HTTP_503, data="ERROR: Invalid public_token") @@ -82,6 +125,7 @@ class ConnectionViewSet(viewsets.ModelViewSet): "account": account }) conn.credentials = plaid_client.credentials + print(f"CREDS: {plaid_client.credentials}") conn.save() return Response(plaid_client.credentials)